Updated docstrings using darglint (#2827)

* Updated docstrings using darglint, ignoring DAR402 and DAR202, and split overflowing lines into multiple lines

* Remove abstract method decorators, leaving their reintroduction for a future PR

* Add `from __future__ import annotations` for Python 3.7+ annotation syntax

* Added missing bracket

* Fix minor docstring tables
Mark Towers
2022-05-25 14:46:41 +01:00
committed by GitHub
parent 4487008ea9
commit 273e3f22ce
37 changed files with 474 additions and 207 deletions


@@ -1,7 +1,6 @@
 """Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
 from __future__ import annotations

-from abc import abstractmethod
 from typing import Generic, Optional, SupportsFloat, TypeVar, Union

 from gym import spaces
@@ -63,7 +62,6 @@ class Env(Generic[ObsType, ActType]):
     def np_random(self, value: RandomNumberGenerator):
         self._np_random = value

-    @abstractmethod
     def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
         """Run one timestep of the environment's dynamics.
@@ -71,7 +69,7 @@ class Env(Generic[ObsType, ActType]):
         Accepts an action and returns a tuple `(observation, reward, done, info)`.

         Args:
-            action (object): an action provided by the agent
+            action (ActType): an action provided by the agent

         Returns:
             observation (object): this will be an element of the environment's :attr:`observation_space`.
@@ -88,7 +86,6 @@ class Env(Generic[ObsType, ActType]):
         """
         raise NotImplementedError

-    @abstractmethod
     def reset(
         self,
         *,
@@ -129,7 +126,6 @@ class Env(Generic[ObsType, ActType]):
         if seed is not None:
             self._np_random, seed = seeding.np_random(seed)

-    @abstractmethod
     def render(self, mode="human"):
         """Renders the environment.
@@ -152,6 +148,7 @@ class Env(Generic[ObsType, ActType]):
         in implementations to use the functionality of this method.

         Example:
+            >>> import numpy as np
             >>> class MyEnv(Env):
             ...     metadata = {'render_modes': ['human', 'rgb_array']}
             ...
@@ -161,7 +158,7 @@ class Env(Generic[ObsType, ActType]):
             ...     elif mode == 'human':
             ...         ...  # pop up a window and render
             ...     else:
-            ...         super(MyEnv, self).render(mode=mode)  # just raise an exception
+            ...         super().render(mode=mode)  # just raise an exception

         Args:
             mode: the mode to render with, valid modes are `env.metadata["render_modes"]`
@@ -208,7 +205,7 @@ class Env(Generic[ObsType, ActType]):
         """Returns the base non-wrapped environment.

         Returns:
-            gym.Env: The base non-wrapped gym.Env instance
+            Env: The base non-wrapped gym.Env instance
         """
         return self
@@ -389,7 +386,6 @@ class ObservationWrapper(Wrapper):
         observation, reward, done, info = self.env.step(action)
         return self.observation(observation), reward, done, info

-    @abstractmethod
     def observation(self, observation):
         """Returns a modified observation."""
         raise NotImplementedError
@@ -424,7 +420,6 @@ class RewardWrapper(Wrapper):
         observation, reward, done, info = self.env.step(action)
         return observation, self.reward(reward), done, info

-    @abstractmethod
     def reward(self, reward):
         """Returns a modified ``reward``."""
         raise NotImplementedError
@@ -466,12 +461,10 @@ class ActionWrapper(Wrapper):
         """Runs the environment :meth:`env.step` using the modified ``action`` from :meth:`self.action`."""
         return self.env.step(self.action(action))

-    @abstractmethod
     def action(self, action):
         """Returns a modified action before :meth:`env.step` is called."""
         raise NotImplementedError

-    @abstractmethod
     def reverse_action(self, action):
         """Returns a reversed ``action``."""
         raise NotImplementedError
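With the `@abstractmethod` decorators gone, `Env` subclasses are no longer forced to override every method at class-definition time; the base implementations simply `raise NotImplementedError` when called. A minimal sketch of the pattern the updated docstrings describe (a hypothetical toy environment, assuming the 0.24-era `reset`/`step` signatures shown above):

```python
import gym
from gym import spaces


class ConstantEnv(gym.Env):
    """Toy env: the observation is always 0 and an episode lasts 10 steps."""

    metadata = {"render_modes": []}

    def __init__(self):
        self.observation_space = spaces.Discrete(1)
        self.action_space = spaces.Discrete(2)
        self._steps = 0

    def reset(self, *, seed=None, return_info=False, options=None):
        super().reset(seed=seed)  # seeds self.np_random
        self._steps = 0
        return (0, {}) if return_info else 0

    def step(self, action):
        self._steps += 1
        done = self._steps >= 10
        return 0, 1.0, done, {}  # (observation, reward, done, info)
```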


@@ -696,15 +696,16 @@ def heuristic(env, s):
     Args:
         env: The environment
         s (list): The state. Attributes:
             s[0] is the horizontal coordinate
             s[1] is the vertical coordinate
             s[2] is the horizontal speed
             s[3] is the vertical speed
             s[4] is the angle
             s[5] is the angular speed
             s[6] 1 if first leg has contact, else 0
             s[7] 1 if second leg has contact, else 0
-    returns:
+
+    Returns:
         a: The heuristic to be fed into the step function defined above to determine the next step and reward.
     """

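A sketch of driving this heuristic in a rollout loop (assumes Box2D is installed; the `demo_heuristic_lander` helper in the same module does essentially this):

```python
import gym
from gym.envs.box2d.lunar_lander import heuristic

env = gym.make("LunarLander-v2")
s = env.reset()
total_reward, done = 0.0, False
while not done:
    a = heuristic(env, s)  # action chosen from the 8-dimensional state
    s, r, done, info = env.step(a)
    total_reward += r
print(f"episode return: {total_reward:.1f}")
```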

@@ -44,8 +44,8 @@ class AcrobotEnv(core.Env):
     The action is discrete, deterministic, and represents the torque applied on the actuated
     joint between the two links.

     | Num | Action                                | Unit         |
-    |----|-------------------------------------------|---------------|
+    |-----|---------------------------------------|--------------|
     | 0   | apply -1 torque to the actuated joint | torque (N m) |
     | 1   | apply 0 torque to the actuated joint  | torque (N m) |
     | 2   | apply 1 torque to the actuated joint  | torque (N m) |
@@ -55,27 +55,29 @@ class AcrobotEnv(core.Env):
     The observation is a `ndarray` with shape `(6,)` that provides information about the
     two rotational joint angles as well as their angular velocities:

     | Num | Observation                  | Min                 | Max               |
-    |-----|-----------------------|----------------------|--------------------|
+    |-----|------------------------------|---------------------|-------------------|
     | 0   | Cosine of `theta1`           | -1                  | 1                 |
     | 1   | Sine of `theta1`             | -1                  | 1                 |
     | 2   | Cosine of `theta2`           | -1                  | 1                 |
     | 3   | Sine of `theta2`             | -1                  | 1                 |
     | 4   | Angular velocity of `theta1` | ~ -12.567 (-4 * pi) | ~ 12.567 (4 * pi) |
     | 5   | Angular velocity of `theta2` | ~ -28.274 (-9 * pi) | ~ 28.274 (9 * pi) |

     where

     - `theta1` is the angle of the first joint, where an angle of 0 indicates the first link is pointing directly
       downwards.
-    - `theta2` is ***relative to the angle of the first link.*** An angle of 0 corresponds to having the same angle between the
-      two links.
+    - `theta2` is ***relative to the angle of the first link.***
+      An angle of 0 corresponds to having the same angle between the two links.

     The angular velocities of `theta1` and `theta2` are bounded at ±4π, and ±9π rad/s respectively.
     A state of `[1, 0, 1, 0, ..., ...]` indicates that both links are pointing downwards.

     ### Rewards

-    The goal is to have the free end reach a designated target height in as few steps as possible, and as such all steps that do not reach the goal incur a reward of -1. Achieving the target height results in termination with a reward of 0. The reward threshold is -100.
+    The goal is to have the free end reach a designated target height in as few steps as possible,
+    and as such all steps that do not reach the goal incur a reward of -1.
+    Achieving the target height results in termination with a reward of 0. The reward threshold is -100.

     ### Starting State
@@ -98,7 +100,8 @@ class AcrobotEnv(core.Env):
     ```
     By default, the dynamics of the acrobot follow those described in Sutton and Barto's book
-    [Reinforcement Learning: An Introduction](http://incompleteideas.net/book/11/node4.html). However, a `book_or_nips` parameter can be modified to change the pendulum dynamics to those described
+    [Reinforcement Learning: An Introduction](http://incompleteideas.net/book/11/node4.html).
+    However, a `book_or_nips` parameter can be modified to change the pendulum dynamics to those described
     in the original [NeurIPS paper](https://papers.nips.cc/paper/1995/hash/8f1d43620bc6bb580df6e80b0dc05c48-Abstract.html).
     ```
@@ -125,7 +128,9 @@ class AcrobotEnv(core.Env):
     - v0: Initial versions release (1.0.0) (removed from gym for v1)

     ### References
-    - Sutton, R. S. (1996). Generalization in Reinforcement Learning: Successful Examples Using Sparse Coarse Coding. In D. Touretzky, M. C. Mozer, & M. Hasselmo (Eds.), Advances in Neural Information Processing Systems (Vol. 8). MIT Press. https://proceedings.neurips.cc/paper/1995/file/8f1d43620bc6bb580df6e80b0dc05c48-Paper.pdf
+    - Sutton, R. S. (1996). Generalization in Reinforcement Learning: Successful Examples Using Sparse Coarse Coding.
+      In D. Touretzky, M. C. Mozer, & M. Hasselmo (Eds.), Advances in Neural Information Processing Systems (Vol. 8).
+      MIT Press. https://proceedings.neurips.cc/paper/1995/file/8f1d43620bc6bb580df6e80b0dc05c48-Paper.pdf
     - Sutton, R. S., Barto, A. G. (2018). Reinforcement Learning: An Introduction. The MIT Press.
     """
@@ -380,6 +385,8 @@ def bound(x, m, M=None):
     Args:
         x: scalar
+        m: The lower bound
+        M: The upper bound

     Returns:
         x: scalar, bound between min (m) and Max (M)
@@ -398,15 +405,15 @@ def rk4(derivs, y0, t):
     yourself stranded on a system w/o scipy. Otherwise use
     :func:`scipy.integrate`.

-    Example:
-        >>> ### 2D system
+    Example for 2D system:

         >>> def derivs(x):
         ...     d1 = x[0] + 2*x[1]
         ...     d2 = -3*x[0] + 4*x[1]
-        ...     return (d1, d2)
+        ...     return d1, d2

         >>> dt = 0.0005
-        >>> t = arange(0.0, 2.0, dt)
+        >>> t = np.arange(0.0, 2.0, dt)
         >>> y0 = (1,2)
         >>> yout = rk4(derivs, y0, t)
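For readers without the gym source handy, the classic fixed-step Runge-Kutta 4 scheme that this doctest exercises, as a self-contained numpy sketch (independent of gym's private `rk4` helper):

```python
import numpy as np


def rk4_step(derivs, y, dt):
    """One classic RK4 step for the autonomous system y' = derivs(y)."""
    k1 = np.asarray(derivs(y))
    k2 = np.asarray(derivs(y + dt / 2 * k1))
    k3 = np.asarray(derivs(y + dt / 2 * k2))
    k4 = np.asarray(derivs(y + dt * k3))
    return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def derivs(x):
    d1 = x[0] + 2 * x[1]
    d2 = -3 * x[0] + 4 * x[1]
    return d1, d2


dt = 0.0005
y = np.array([1.0, 2.0])
for _ in np.arange(0.0, 2.0, dt):
    y = rk4_step(derivs, y, dt)
```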


@@ -17,40 +17,47 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
     """
     ### Description

-    This environment corresponds to the version of the cart-pole problem
-    described by Barto, Sutton, and Anderson in ["Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problem"](https://ieeexplore.ieee.org/document/6313077).
-    A pole is attached by an un-actuated joint to a cart, which moves along a
-    frictionless track. The pendulum is placed upright on the cart and the goal is to balance the pole by applying forces in the left and right direction on the cart.
+    This environment corresponds to the version of the cart-pole problem described by Barto, Sutton, and Anderson in
+    ["Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problem"](https://ieeexplore.ieee.org/document/6313077).
+    A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track.
+    The pendulum is placed upright on the cart and the goal is to balance the pole by applying forces
+    in the left and right direction on the cart.

     ### Action Space

-    The action is a `ndarray` with shape `(1,)` which can take values `{0, 1}` indicating the direction of the fixed force the cart is pushed with.
+    The action is a `ndarray` with shape `(1,)` which can take values `{0, 1}` indicating the direction
+    of the fixed force the cart is pushed with.

     | Num | Action                 |
     |-----|------------------------|
     | 0   | Push cart to the left  |
     | 1   | Push cart to the right |

-    **Note**: The velocity that is reduced or increased by the applied force is not fixed and it depends on the angle the pole is pointing. The center of gravity of the pole varies the amount of energy needed to move the cart underneath it
+    **Note**: The velocity that is reduced or increased by the applied force is not fixed and it depends on the angle
+    the pole is pointing. The center of gravity of the pole varies the amount of energy needed to move the cart underneath it

     ### Observation Space

     The observation is a `ndarray` with shape `(4,)` with the values corresponding to the following positions and velocities:

     | Num | Observation           | Min                 | Max               |
-    |-----|-----------------------|----------------------|--------------------|
+    |-----|-----------------------|---------------------|-------------------|
     | 0   | Cart Position         | -4.8                | 4.8               |
     | 1   | Cart Velocity         | -Inf                | Inf               |
     | 2   | Pole Angle            | ~ -0.418 rad (-24°) | ~ 0.418 rad (24°) |
     | 3   | Pole Angular Velocity | -Inf                | Inf               |

-    **Note:** While the ranges above denote the possible values for observation space of each element, it is not reflective of the allowed values of the state space in an unterminated episode. Particularly:
-    - The cart x-position (index 0) can be take values between `(-4.8, 4.8)`, but the episode terminates if the cart leaves the `(-2.4, 2.4)` range.
-    - The pole angle can be observed between `(-.418, .418)` radians (or **±24°**), but the episode terminates if the pole angle is not in the range `(-.2095, .2095)` (or **±12°**)
+    **Note:** While the ranges above denote the possible values for observation space of each element,
+    it is not reflective of the allowed values of the state space in an unterminated episode. Particularly:
+    - The cart x-position (index 0) can take values between `(-4.8, 4.8)`, but the episode terminates
+      if the cart leaves the `(-2.4, 2.4)` range.
+    - The pole angle can be observed between `(-.418, .418)` radians (or **±24°**), but the episode terminates
+      if the pole angle is not in the range `(-.2095, .2095)` (or **±12°**)

     ### Rewards

-    Since the goal is to keep the pole upright for as long as possible, a reward of `+1` for every step taken, including the termination step, is allotted. The threshold for rewards is 475 for v1.
+    Since the goal is to keep the pole upright for as long as possible, a reward of `+1` for every step taken,
+    including the termination step, is allotted. The threshold for rewards is 475 for v1.

     ### Starting State
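A quick sketch of interacting with this environment as the docstring describes (random actions; 0.24-era API where `step` returns a 4-tuple):

```python
import gym

env = gym.make("CartPole-v1")
obs = env.reset(seed=0)
done, steps = False, 0
while not done:
    action = env.action_space.sample()           # 0: push left, 1: push right
    obs, reward, done, info = env.step(action)   # reward is +1 every step
    steps += 1
print(f"episode length: {steps}")
```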


@@ -49,14 +49,15 @@ class Continuous_MountainCarEnv(gym.Env):
     The observation is a `ndarray` with shape `(2,)` where the elements correspond to the following:

     | Num | Observation                          | Min  | Max | Unit         |
-    |-----|-------------------------------------------------------------|--------------------|--------|------|
+    |-----|--------------------------------------|------|-----|--------------|
     | 0   | position of the car along the x-axis | -Inf | Inf | position (m) |
     | 1   | velocity of the car                  | -Inf | Inf | position (m) |

     ### Action Space

-    The action is a `ndarray` with shape `(1,)`, representing the directional force applied on the car. The action is clipped in the range `[-1,1]` and multiplied by a power of 0.0015.
+    The action is a `ndarray` with shape `(1,)`, representing the directional force applied on the car.
+    The action is clipped in the range `[-1,1]` and multiplied by a power of 0.0015.

     ### Transition Dynamics:
@@ -66,15 +67,20 @@ class Continuous_MountainCarEnv(gym.Env):
     *position<sub>t+1</sub> = position<sub>t</sub> + velocity<sub>t+1</sub>*

-    where force is the action clipped to the range `[-1,1]` and power is a constant 0.0015. The collisions at either end are inelastic with the velocity set to 0 upon collision with the wall. The position is clipped to the range [-1.2, 0.6] and velocity is clipped to the range [-0.07, 0.07].
+    where force is the action clipped to the range `[-1,1]` and power is a constant 0.0015.
+    The collisions at either end are inelastic with the velocity set to 0 upon collision with the wall.
+    The position is clipped to the range [-1.2, 0.6] and velocity is clipped to the range [-0.07, 0.07].

     ### Reward

-    A negative reward of *-0.1 * action<sup>2</sup>* is received at each timestep to penalise for taking actions of large magnitude. If the mountain car reaches the goal then a positive reward of +100 is added to the negative reward for that timestep.
+    A negative reward of *-0.1 * action<sup>2</sup>* is received at each timestep to penalise for
+    taking actions of large magnitude. If the mountain car reaches the goal then a positive reward of +100
+    is added to the negative reward for that timestep.

     ### Starting State

-    The position of the car is assigned a uniform random value in `[-0.6 , -0.4]`. The starting velocity of the car is always assigned to 0.
+    The position of the car is assigned a uniform random value in `[-0.6 , -0.4]`.
+    The starting velocity of the car is always assigned to 0.

     ### Episode Termination
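The transition rule above, written out as a runnable sketch (the `-0.0025 * cos(3 * position)` gravity term comes from the gym source rather than the excerpt shown here, so treat it as an assumption):

```python
import numpy as np

POWER, GRAVITY = 0.0015, 0.0025


def transition(position, velocity, action):
    """One step of the continuous mountain-car dynamics from the docstring."""
    force = np.clip(action, -1.0, 1.0)
    velocity += force * POWER - GRAVITY * np.cos(3 * position)
    velocity = float(np.clip(velocity, -0.07, 0.07))
    position = float(np.clip(position + velocity, -1.2, 0.6))
    if position <= -1.2 and velocity < 0:
        velocity = 0.0  # inelastic collision with the left wall
    return position, velocity
```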


@@ -38,20 +38,20 @@ class MountainCarEnv(gym.Env):
     The observation is a `ndarray` with shape `(2,)` where the elements correspond to the following:

     | Num | Observation                          | Min  | Max | Unit         |
-    |-----|-------------------------------------------------------------|--------------------|--------|------|
+    |-----|--------------------------------------|------|-----|--------------|
     | 0   | position of the car along the x-axis | -Inf | Inf | position (m) |
     | 1   | velocity of the car                  | -Inf | Inf | position (m) |

     ### Action Space

     There are 3 discrete deterministic actions:

     | Num | Observation             | Value | Unit         |
-    |-----|-------------------------------------------------------------|---------|------|
+    |-----|-------------------------|-------|--------------|
     | 0   | Accelerate to the left  | Inf   | position (m) |
     | 1   | Don't accelerate        | Inf   | position (m) |
     | 2   | Accelerate to the right | Inf   | position (m) |

     ### Transition Dynamics:
@@ -61,16 +61,21 @@ class MountainCarEnv(gym.Env):
     *position<sub>t+1</sub> = position<sub>t</sub> + velocity<sub>t+1</sub>*

-    where force = 0.001 and gravity = 0.0025. The collisions at either end are inelastic with the velocity set to 0 upon collision with the wall. The position is clipped to the range `[-1.2, 0.6]` and velocity is clipped to the range `[-0.07, 0.07]`.
+    where force = 0.001 and gravity = 0.0025. The collisions at either end are inelastic with the velocity set to 0
+    upon collision with the wall. The position is clipped to the range `[-1.2, 0.6]` and
+    velocity is clipped to the range `[-0.07, 0.07]`.

     ### Reward:

-    The goal is to reach the flag placed on top of the right hill as quickly as possible, as such the agent is penalised with a reward of -1 for each timestep it isn't at the goal and is not penalised (reward = 0) for when it reaches the goal.
+    The goal is to reach the flag placed on top of the right hill as quickly as possible, as such the agent is
+    penalised with a reward of -1 for each timestep it isn't at the goal and is not penalised (reward = 0) for
+    when it reaches the goal.

     ### Starting State

-    The position of the car is assigned a uniform random value in *[-0.6 , -0.4]*. The starting velocity of the car is always assigned to 0.
+    The position of the car is assigned a uniform random value in *[-0.6 , -0.4]*.
+    The starting velocity of the car is always assigned to 0.

     ### Episode Termination


@@ -14,7 +14,10 @@ class PendulumEnv(gym.Env):
     """
     ### Description

-    The inverted pendulum swingup problem is based on the classic problem in control theory. The system consists of a pendulum attached at one end to a fixed point, and the other end being free. The pendulum starts in a random position and the goal is to apply torque on the free end to swing it into an upright position, with its center of gravity right above the fixed point.
+    The inverted pendulum swingup problem is based on the classic problem in control theory.
+    The system consists of a pendulum attached at one end to a fixed point, and the other end being free.
+    The pendulum starts in a random position and the goal is to apply torque on the free end to swing it
+    into an upright position, with its center of gravity right above the fixed point.

     The diagram below specifies the coordinate system used for the implementation of the pendulum's
     dynamic equations.
@@ -36,7 +39,8 @@ class PendulumEnv(gym.Env):
     ### Observation Space

-    The observation is a `ndarray` with shape `(3,)` representing the x-y coordinates of the pendulum's free end and its angular velocity.
+    The observation is a `ndarray` with shape `(3,)` representing the x-y coordinates of the pendulum's free
+    end and its angular velocity.

     | Num | Observation      | Min  | Max |
     |-----|------------------|------|-----|
@@ -51,8 +55,9 @@ class PendulumEnv(gym.Env):
     *r = -(theta<sup>2</sup> + 0.1 * theta_dt<sup>2</sup> + 0.001 * torque<sup>2</sup>)*

     where `$\theta$` is the pendulum's angle normalized between *[-pi, pi]* (with 0 being in the upright position).
-    Based on the above equation, the minimum reward that can be obtained is *-(pi<sup>2</sup> + 0.1 * 8<sup>2</sup> + 0.001 * 2<sup>2</sup>) = -16.2736044*, while the maximum reward is zero (pendulum is
-    upright with zero velocity and no torque applied).
+    Based on the above equation, the minimum reward that can be obtained is
+    *-(pi<sup>2</sup> + 0.1 * 8<sup>2</sup> + 0.001 * 2<sup>2</sup>) = -16.2736044*,
+    while the maximum reward is zero (pendulum is upright with zero velocity and no torque applied).

     ### Starting State
@@ -64,7 +69,8 @@ class PendulumEnv(gym.Env):
     ### Arguments

-    - `g`: acceleration of gravity measured in *(m s<sup>-2</sup>)* used to calculate the pendulum dynamics. The default value is g = 10.0 .
+    - `g`: acceleration of gravity measured in *(m s<sup>-2</sup>)* used to calculate the pendulum dynamics.
+      The default value is g = 10.0 .

     ```
     gym.make('Pendulum-v1', g=9.81)


@@ -122,10 +122,8 @@ class MujocoEnv(gym.Env):
     def viewer_setup(self):
         """
         This method is called when the viewer is initialized.
-        Optionally implement this method, if you need to tinker with camera position
-        and so forth.
+        Optionally implement this method, if you need to tinker with camera position and so forth.
         """
-        pass

     # -----------------------------
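A sketch of the kind of override this docstring invites (the camera attribute names follow mujoco-py's `viewer.cam` API and are an assumption here, not part of this diff):

```python
from gym.envs.mujoco import MujocoEnv


class TiltedCameraEnv(MujocoEnv):
    """Hypothetical subclass that only customises the viewer camera."""

    def viewer_setup(self):
        # zoom out relative to the model extent and tilt the default camera
        self.viewer.cam.distance = self.model.stat.extent * 1.5
        self.viewer.cam.elevation = -20
```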


@@ -14,15 +14,15 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
     ### Action Space

     The action space is a `Box(-2, 2, (7,), float32)`. An action `(a, b)` represents the torques applied at the hinge joints.

     | Num | Action | Control Min | Control Max | Name (in corresponding XML file) | Joint | Unit |
-    |-----|--------------------------------------------------------------------|-------------|-------------|----------------------------------|-------|--------------|
+    |-----|------------------------------------------------|-------------|-------------|----------------------------------|-------|--------------|
     | 0 | Rotation of the panning the shoulder | -2 | 2 | r_shoulder_pan_joint | hinge | torque (N m) |
     | 1 | Rotation of the shoulder lifting joint | -2 | 2 | r_shoulder_lift_joint | hinge | torque (N m) |
     | 2 | Rotation of the shoulder rolling joint | -2 | 2 | r_upper_arm_roll_joint | hinge | torque (N m) |
     | 3 | Rotation of hinge joint that flexed the elbow | -2 | 2 | r_elbow_flex_joint | hinge | torque (N m) |
     | 4 | Rotation of hinge that rolls the forearm | -2 | 2 | r_forearm_roll_joint | hinge | torque (N m) |
     | 5 | Rotation of flexing the wrist | -2 | 2 | r_wrist_flex_joint | hinge | torque (N m) |
     | 6 | Rotation of rolling the wrist | -2 | 2 | r_wrist_roll_joint | hinge | torque (N m) |

     ### Observation Space


@@ -30,19 +30,19 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
     The observation is a `ndarray` with shape `(11,)` where the elements correspond to the following:

     | Num | Observation | Min | Max | Name (in corresponding XML file) | Joint | Unit |
-    |-----|-------------------------------------------------------------------------------------------------|------|-----|----------------------------------|-------|--------------------------|
+    |-----|------------------------------------------------------------------------------------------------|------|-----|----------------------------------|-------|--------------------------|
     | 0 | cosine of the angle of the first arm | -Inf | Inf | cos(joint0) | hinge | unitless |
     | 1 | cosine of the angle of the second arm | -Inf | Inf | cos(joint1) | hinge | unitless |
     | 2 | sine of the angle of the first arm | -Inf | Inf | cos(joint0) | hinge | unitless |
     | 3 | sine of the angle of the second arm | -Inf | Inf | cos(joint1) | hinge | unitless |
     | 4 | x-coordinate of the target | -Inf | Inf | target_x | slide | position (m) |
     | 5 | y-coordinate of the target | -Inf | Inf | target_y | slide | position (m) |
     | 6 | angular velocity of the first arm | -Inf | Inf | joint0 | hinge | angular velocity (rad/s) |
     | 7 | angular velocity of the second arm | -Inf | Inf | joint1 | hinge | angular velocity (rad/s) |
     | 8 | x-value of position_fingertip - position_target | -Inf | Inf | NA | slide | position (m) |
     | 9 | y-value of position_fingertip - position_target | -Inf | Inf | NA | slide | position (m) |
     | 10 | z-value of position_fingertip - position_target (0 since reacher is 2d and z is same for both) | -Inf | Inf | NA | slide | position (m) |

     Most Gym environments just return the positions and velocity of the


@@ -16,8 +16,6 @@ from typing import (
     Optional,
     Sequence,
     SupportsFloat,
-    Tuple,
-    Type,
     Union,
     overload,
 )
@@ -49,14 +47,14 @@ ENV_ID_RE: re.Pattern = re.compile(
 )


-def load(name: str) -> Type:
+def load(name: str) -> type:
     mod_name, attr_name = name.split(":")
     mod = importlib.import_module(mod_name)
     fn = getattr(mod, attr_name)
     return fn


-def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
+def parse_env_id(id: str) -> tuple[Optional[str], str, Optional[int]]:
     """Parse environment ID string format.

     This format is true today, but it's *not* an official spec.
@@ -64,6 +62,15 @@ def parse_env_id(id: str) -> tuple[Optional[str], str, Optional[int]]:
     2016-10-31: We're experimentally expanding the environment ID format
     to include an optional namespace.

+    Args:
+        id: The environment id to parse
+
+    Returns:
+        A tuple of environment namespace, environment name and version number
+
+    Raises:
+        Error: If the environment id does not match a valid environment regex
     """
     match = ENV_ID_RE.fullmatch(id)
     if not match:
@@ -78,9 +85,17 @@ def parse_env_id(id: str) -> tuple[Optional[str], str, Optional[int]]:
     return namespace, name, version


-def get_env_id(ns: Optional[str], name: str, version: Optional[int]):
-    """Get the full env ID given a name and (optional) version and namespace.
-    Inverse of parse_env_id."""
+def get_env_id(ns: Optional[str], name: str, version: Optional[int]) -> str:
+    """Get the full env ID given a name and (optional) version and namespace. Inverse of :meth:`parse_env_id`.
+
+    Args:
+        ns: The environment namespace
+        name: The environment name
+        version: The environment version
+
+    Returns:
+        The environment id
+    """

     full_name = name
     if version is not None:
@@ -172,7 +187,18 @@ def _check_name_exists(ns: Optional[str], name: str):
 def _check_version_exists(ns: Optional[str], name: str, version: Optional[int]):
     """Check if an env version exists in a namespace. If it doesn't, print a helpful error message.
-    This is a complete test whether an environment identifier is valid, and will provide the best available hints."""
+
+    This is a complete test whether an environment identifier is valid, and will provide the best available hints.
+
+    Args:
+        ns: The environment namespace
+        name: The environment name
+        version: The environment version
+
+    Raises:
+        DeprecatedEnv: The environment doesn't exist but a default version does
+        VersionNotFound: The ``version`` used doesn't exist
+        DeprecatedEnv: Environment version is deprecated
+    """
     if get_env_id(ns, name, version) in registry:
         return
@@ -344,6 +370,7 @@ class EnvRegistry(dict):
     Turns out that some existing code directly used the old `EnvRegistry` code,
     even though the intended API was just `register` and `make`.
+    This reimplements some old methods, so that e.g. pybullet environments will still work.
     Ideally, nobody should ever use these methods, and they will be removed soon.
     """
@@ -458,13 +485,16 @@ def namespace(ns: str):
 def register(id: str, **kwargs):
-    """
-    Register an environment with gym. The `id` parameter corresponds to the name of the environment,
-    with the syntax as follows:
-    `(namespace)/(env_name)-v(version)`
-    where `namespace` is optional.
+    """Register an environment with gym.
+
+    The `id` parameter corresponds to the name of the environment, with the syntax as follows:
+    `(namespace)/(env_name)-v(version)` where `namespace` is optional.

     It takes arbitrary keyword arguments, which are passed to the `EnvSpec` constructor.
+
+    Args:
+        id: The environment id
+        **kwargs: arbitrary keyword arguments which are passed to the environment constructor
     """
     global registry, current_namespace
     ns, name, version = parse_env_id(id)
@@ -498,8 +528,7 @@ def make(
     disable_env_checker: bool = False,
     **kwargs,
 ) -> Env:
-    """
-    Create an environment according to the given ID.
+    """Create an environment according to the given ID.

     Warnings:
         In v0.24, `gym.utils.env_checker.env_checker` is run for every initialised environment.
@@ -512,8 +541,12 @@ def make(
         autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
         disable_env_checker: If to disable the environment checker
         kwargs: Additional arguments to pass to the environment constructor.
+
     Returns:
         An instance of the environment.
+
+    Raises:
+        Error: If the ``id`` doesn't exist then an error is raised
     """
     if isinstance(id, EnvSpec):
         spec_ = id
@@ -588,7 +621,8 @@ def make(
             check_env(env)
         except Exception as e:
             logger.warn(
-                f"Env check failed with the following message: {e}\nYou can call `gym.make(..., disable_env_checker=True)` to disable this check."
+                f"Env check failed with the following message: {e}\n"
+                f"You can set `disable_env_checker=True` to disable this check."
             )
     return env
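A sketch of the registration flow these docstrings describe (the id, module path, and class name below are hypothetical):

```python
import gym
from gym.envs.registration import register

register(
    id="MyNamespace/GridWorld-v0",               # (namespace)/(env_name)-v(version)
    entry_point="my_package.envs:GridWorldEnv",  # hypothetical module:class
    max_episode_steps=100,                       # extra kwargs go to EnvSpec
)

env = gym.make("MyNamespace/GridWorld-v0")
```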


@@ -42,8 +42,10 @@ class CliffWalkingEnv(Env):
     - 3: move left

     ### Observations

-    There are 3x12 + 1 possible states. In fact, the agent cannot be at the cliff, nor at the goal (as this results the end of episode). They remain all the positions of the first 3 rows plus the bottom-left cell.
-    The observation is simply the current position encoded as [flattened index](https://numpy.org/doc/stable/reference/generated/numpy.unravel_index.html).
+    There are 3x12 + 1 possible states. In fact, the agent cannot be at the cliff, nor at the goal
+    (as this results in the end of the episode). They remain all the positions of the first 3 rows plus the bottom-left cell.
+    The observation is simply the current position encoded as
+    [flattened index](https://numpy.org/doc/stable/reference/generated/numpy.unravel_index.html).

     ### Reward

     Each time step incurs -1 reward, and stepping into the cliff incurs -100 reward.
@@ -89,12 +91,8 @@ class CliffWalkingEnv(Env):
         self.observation_space = spaces.Discrete(self.nS)
         self.action_space = spaces.Discrete(self.nA)

-    def _limit_coordinates(self, coord):
-        """
-        Prevent the agent from falling out of the grid world
-        :param coord:
-        :return:
-        """
+    def _limit_coordinates(self, coord: np.ndarray) -> np.ndarray:
+        """Prevent the agent from falling out of the grid world."""
         coord[0] = min(coord[0], self.shape[0] - 1)
         coord[0] = max(coord[0], 0)
         coord[1] = min(coord[1], self.shape[1] - 1)
@@ -102,11 +100,14 @@ class CliffWalkingEnv(Env):
         return coord

     def _calculate_transition_prob(self, current, delta):
-        """
-        Determine the outcome for an action. Transition Prob is always 1.0.
-        :param current: Current position on the grid as (row, col)
-        :param delta: Change in position for transition
-        :return: (1.0, new_state, reward, done)
+        """Determine the outcome for an action. Transition Prob is always 1.0.
+
+        Args:
+            current: Current position on the grid as (row, col)
+            delta: Change in position for transition
+
+        Returns:
+            Tuple of ``(1.0, new_state, reward, done)``
         """
         new_position = np.array(current) + np.array(delta)
         new_position = self._limit_coordinates(new_position).astype(int)


@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from contextlib import closing
 from io import StringIO
 from os import path
@@ -29,10 +31,15 @@ MAPS = {
 }


-def generate_random_map(size=8, p=0.8):
-    """Generates a random valid map (one that has a path from start to goal)
-    :param size: size of each side of the grid
-    :param p: probability that a tile is frozen
+def generate_random_map(size: int = 8, p: float = 0.8) -> list[str]:
+    """Generates a random valid map (one that has a path from start to goal)
+
+    Args:
+        size: size of each side of the grid
+        p: probability that a tile is frozen
+
+    Returns:
+        A random valid map
     """
     valid = False
@@ -67,8 +74,9 @@ def generate_random_map(size=8, p=0.8):
 class FrozenLakeEnv(Env):
     """
-    Frozen lake involves crossing a frozen lake from Start(S) to Goal(G) without falling into any Holes(H) by walking over
-    the Frozen(F) lake. The agent may not always move in the intended direction due to the slippery nature of the frozen lake.
+    Frozen lake involves crossing a frozen lake from Start(S) to Goal(G) without falling into any Holes(H)
+    by walking over the Frozen(F) lake.
+    The agent may not always move in the intended direction due to the slippery nature of the frozen lake.

     ### Action Space
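A sketch of pairing `generate_random_map` with `FrozenLake` (the `desc` and `is_slippery` kwargs are the environment's documented constructor arguments):

```python
import gym
from gym.envs.toy_text.frozen_lake import generate_random_map

# e.g. ["SFFF", "FHFH", "FFFH", "HFFG"]; always has a path from S to G
random_map = generate_random_map(size=4, p=0.9)
env = gym.make("FrozenLake-v1", desc=random_map, is_slippery=False)
obs = env.reset()
```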


@@ -1,11 +1,10 @@
 import numpy as np

+from gym.utils import seeding

-def categorical_sample(prob_n, np_random):
-    """
-    Sample from categorical distribution
-    Each row specifies class probabilities
-    """
+
+def categorical_sample(prob_n, np_random: seeding.RandomNumberGenerator):
+    """Sample from categorical distribution where each row specifies class probabilities."""
     prob_n = np.asarray(prob_n)
     csprob_n = np.cumsum(prob_n)
-    return (csprob_n > np_random.random()).argmax()
+    return np.argmax(csprob_n > np_random.random())
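What the helper computes, as a runnable sketch (inverse-CDF sampling; numpy's default `Generator` stands in for gym's `RandomNumberGenerator`, which wraps the same interface):

```python
import numpy as np


def categorical_sample(prob_n, np_random):
    """Draw one class index from class probabilities via the inverse CDF."""
    csprob_n = np.cumsum(np.asarray(prob_n))
    return np.argmax(csprob_n > np_random.random())


rng = np.random.default_rng(0)
counts = np.bincount(
    [categorical_sample([0.2, 0.5, 0.3], rng) for _ in range(10_000)]
)
print(counts / counts.sum())  # ≈ [0.2, 0.5, 0.3]
```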


@@ -1,7 +1,7 @@
 """Implementation of a space that represents closed boxes in euclidean space."""
 from __future__ import annotations

-from typing import Optional, Sequence, SupportsFloat, Tuple, Type, Union
+from typing import Optional, Sequence, SupportsFloat, Union

 import numpy as np
@@ -15,6 +15,12 @@ def _short_repr(arr: np.ndarray) -> str:
     If arr is a multiple of the all-ones vector, return a string representation of the multiplier.
     Otherwise, return a string representation of the entire array.

+    Args:
+        arr: The array to represent
+
+    Returns:
+        A short representation of the array
     """
     if arr.size != 0 and np.min(arr) == np.max(arr):
         return str(np.min(arr))
@@ -46,7 +52,7 @@ class Box(Space[np.ndarray]):
         low: Union[SupportsFloat, np.ndarray],
         high: Union[SupportsFloat, np.ndarray],
         shape: Optional[Sequence[int]] = None,
-        dtype: Type = np.float32,
+        dtype: type = np.float32,
         seed: Optional[int | seeding.RandomNumberGenerator] = None,
     ):
         r"""Constructor of :class:`Box`.
@@ -57,7 +63,6 @@ class Box(Space[np.ndarray]):
         If ``low`` (or ``high``) is a scalar, the lower bound (or upper bound, respectively) will be assumed to be
         this value across all dimensions.
-
         Args:
             low (Union[SupportsFloat, np.ndarray]): Lower bounds of the intervals.
             high (Union[SupportsFloat, np.ndarray]): Upper bounds of the intervals.
@@ -65,6 +70,10 @@ class Box(Space[np.ndarray]):
                 Otherwise, the shape is inferred from the shape of ``low`` or ``high``.
             dtype: The dtype of the elements of the space. If this is an integer type, the :class:`Box` is essentially a discrete space.
             seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
+
+        Raises:
+            ValueError: If no shape information is provided (shape is None, low is None and high is None) then a
+                value error is raised.
         """
         assert dtype is not None, "dtype must be explicitly provided. "
         self.dtype = np.dtype(dtype)
@@ -96,7 +105,7 @@ class Box(Space[np.ndarray]):
         assert isinstance(high, np.ndarray)
         assert high.shape == shape, "high.shape doesn't match provided shape"

-        self._shape: Tuple[int, ...] = shape
+        self._shape: tuple[int, ...] = shape

         low_precision = get_precision(low.dtype)
         high_precision = get_precision(high.dtype)
@@ -112,7 +121,7 @@ class Box(Space[np.ndarray]):
         super().__init__(self.shape, self.dtype, seed)

     @property
-    def shape(self) -> Tuple[int, ...]:
+    def shape(self) -> tuple[int, ...]:
         """Has stricter type than gym.Space - never None."""
         return self._shape
@@ -122,6 +131,9 @@ class Box(Space[np.ndarray]):
         Args:
             manner (str): One of ``"both"``, ``"below"``, ``"above"``.

+        Returns:
+            If the space is bounded
+
         Raises:
             ValueError: If `manner` is neither ``"both"`` nor ``"below"`` or ``"above"``
         """
@@ -146,6 +158,9 @@ class Box(Space[np.ndarray]):
         * :math:`[a, \infty)` : shifted exponential distribution
         * :math:`(-\infty, b]` : shifted negative exponential distribution
         * :math:`(-\infty, \infty)` : normal distribution

+        Returns:
+            A sampled value from the Box
         """
         high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
         sample = np.empty(self.shape)
@@ -204,6 +219,9 @@ class Box(Space[np.ndarray]):
         The representation will include bounds, shape and dtype.
         If a bound is uniform, only the corresponding scalar will be given to avoid redundant and ugly strings.

+        Returns:
+            A representation of the space
         """
         return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
@@ -223,6 +241,13 @@ def get_inf(dtype, sign: str) -> SupportsFloat:
     Args:
         dtype: An `np.dtype`
         sign (str): must be either `"+"` or `"-"`

+    Returns:
+        Gets an infinite value with the sign and dtype
+
+    Raises:
+        TypeError: Unknown sign, use either '+' or '-'
+        ValueError: Unknown dtype for infinite bounds
     """
     if np.dtype(dtype).kind == "f":
         if sign == "+":
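A quick sketch of the constructor behaviours the docstring describes (scalar bounds broadcast over `shape`; unbounded dimensions are sampled from the distributions listed above):

```python
import numpy as np
from gym.spaces import Box

# scalar low/high broadcast across the given shape
obs_space = Box(low=-1.0, high=1.0, shape=(3, 4), dtype=np.float32)
print(obs_space)                  # Box(-1.0, 1.0, (3, 4), float32)
print(obs_space.sample().shape)   # (3, 4)

# mixed bounds: the unbounded dimension uses a different sampling distribution
mixed = Box(low=np.array([0.0, -np.inf]), high=np.array([1.0, np.inf]))
assert mixed.sample() in mixed
```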


@@ -143,6 +143,9 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
         """Generates a single random sample from this space.

         The sample is an ordered dictionary of independent samples from the constituent spaces.

+        Returns:
+            A dictionary with the same key and sampled values from :attr:`self.spaces`
         """
         return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
@@ -157,11 +160,11 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
                 return False
         return True

-    def __getitem__(self, key):
+    def __getitem__(self, key: str) -> Space:
         """Get the space that is associated to `key`."""
         return self.spaces[key]

-    def __setitem__(self, key, value):
+    def __setitem__(self, key: str, value: Space):
         """Set the space that is associated to `key`."""
         self.spaces[key] = value
@@ -175,11 +178,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
     def __repr__(self) -> str:
         """Gives a string representation of this space."""
-        return (
-            "Dict("
-            + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()])
-            + ")"
-        )
+        return "Dict(" + ", ".join([f"{k}: {s}" for k, s in self.spaces.items()]) + ")"

     def to_jsonable(self, sample_n: list) -> dict:
         """Convert a batch of samples from this space to a JSONable data type."""

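A sketch of the mapping behaviour these methods give `Dict` (keys index into `self.spaces`):

```python
from gym.spaces import Box, Dict, Discrete

space = Dict({"position": Discrete(2), "velocity": Box(-1.0, 1.0, shape=(1,))})
print(space["position"])   # __getitem__ -> Discrete(2)
sample = space.sample()    # OrderedDict with one entry per subspace
assert sample in space
```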

@@ -45,6 +45,9 @@ class Discrete(Space[int]):
         """Generates a single random sample from this space.

         A sample will be chosen uniformly at random.

+        Returns:
+            A sampled integer from the space
         """
         return int(self.start + self.np_random.integers(self.n))
@@ -78,6 +81,9 @@ class Discrete(Space[int]):
         """Used when loading a pickled space.

         This method has to be implemented explicitly to allow for loading of legacy states.

+        Args:
+            state: The new state
         """
         super().__setstate__(state)


@@ -57,6 +57,9 @@ class MultiBinary(Space[np.ndarray]):
         """Generates a single random sample from this space.

         A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).

+        Returns:
+            Sampled values from space
         """
         return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)


@@ -1,7 +1,7 @@
 """Implementation of the `Space` metaclass."""
 from __future__ import annotations

-from typing import Generic, Iterable, Mapping, Optional, Sequence, Type, TypeVar
+from typing import Generic, Iterable, Mapping, Optional, Sequence, TypeVar

 import numpy as np
@@ -16,8 +16,10 @@ class Space(Generic[T_cov]):
     Spaces are crucially used in Gym to define the format of valid actions and observations.
     They serve various purposes:

-    * They clearly define how to interact with environments, i.e. they specify what actions need to look like and what observations will look like
-    * They allow us to work with highly structured data (e.g. in the form of elements of :class:`Dict` spaces) and painlessly transform them into flat arrays that can be used in learning code
+    * They clearly define how to interact with environments, i.e. they specify what actions need to look like
+      and what observations will look like
+    * They allow us to work with highly structured data (e.g. in the form of elements of :class:`Dict` spaces)
+      and painlessly transform them into flat arrays that can be used in learning code
     * They provide a method to sample random elements. This is especially useful for exploration and debugging.

     Different spaces can be combined hierarchically via container spaces (:class:`Tuple` and :class:`Dict`) to build a
@@ -37,7 +39,7 @@ class Space(Generic[T_cov]):
     def __init__(
         self,
         shape: Optional[Sequence[int]] = None,
-        dtype: Optional[Type | str] = None,
+        dtype: Optional[type | str | np.dtype] = None,
         seed: Optional[int | seeding.RandomNumberGenerator] = None,
     ):
         """Constructor of :class:`Space`.
@@ -90,6 +92,9 @@ class Space(Generic[T_cov]):
         """Used when loading a pickled space.

         This method was implemented explicitly to allow for loading of legacy states.

+        Args:
+            state: The updated state value
         """
         # Don't mutate the original state
         state = dict(state)


@@ -79,6 +79,9 @@ class Tuple(Space[tuple], Sequence):
         """Generates a single random sample inside this space.

         This method draws independent samples from the subspaces.

+        Returns:
+            Tuple of the subspace's samples
         """
         return tuple(space.sample() for space in self.spaces)


@@ -1,6 +1,7 @@
"""Implementation of utility functions that can be applied to spaces. """Implementation of utility functions that can be applied to spaces.
These functions mostly take care of flattening and unflattening elements of spaces to facilitate their usage in learning code. These functions mostly take care of flattening and unflattening elements of spaces
to facilitate their usage in learning code.
""" """
from __future__ import annotations from __future__ import annotations
@@ -18,17 +19,21 @@ from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, T
def flatdim(space: Space) -> int: def flatdim(space: Space) -> int:
"""Return the number of dimensions a flattened equivalent of this space would have. """Return the number of dimensions a flattened equivalent of this space would have.
Accepts a space and returns an integer.
Raises:
NotImplementedError: if the space is not defined in ``gym.spaces``.
Example usage:: Example usage::
>>> from gym.spaces import Discrete >>> from gym.spaces import Discrete
>>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)}) >>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
>>> flatdim(space) >>> flatdim(space)
5 5
Args:
space: The space to return the number of dimensions of the flattened spaces
Returns:
The number of dimensions for the flattened spaces
Raises:
NotImplementedError: if the space is not defined in ``gym.spaces``.
""" """
raise NotImplementedError(f"Unknown space: `{space}`") raise NotImplementedError(f"Unknown space: `{space}`")
@@ -69,9 +74,15 @@ def flatten(space: Space[T], x: T) -> np.ndarray:
This is useful when e.g. points from spaces must be passed to a neural This is useful when e.g. points from spaces must be passed to a neural
network, which only understands flat arrays of floats. network, which only understands flat arrays of floats.
Accepts a space and a point from that space. Always returns a 1D array. Raises ``NotImplementedError`` if the space is not defined in ``gym.spaces``.
Args:
space: The space that ``x`` is flattened by
x: The value to flatten
Returns:
The flattened ``x``, always returns a 1D array.
Raises:
NotImplementedError: If the space is not defined in ``gym.spaces``.
""" """
raise NotImplementedError(f"Unknown space: `{space}`") raise NotImplementedError(f"Unknown space: `{space}`")
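A minimal sketch of the contract above for a ``Box`` space (assuming the public ``gym.spaces.flatten`` export):
>>> from gym.spaces import Box, flatten
>>> box = Box(0.0, 1.0, shape=(3, 4))
>>> flatten(box, box.sample()).shape  # always a 1D array
(12,)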
@@ -116,9 +127,15 @@ def unflatten(space: Space[T], x: np.ndarray) -> T:
This reverses the transformation applied by :func:`flatten`. You must ensure This reverses the transformation applied by :func:`flatten`. You must ensure
that the ``space`` argument is the same as for the :func:`flatten` call. that the ``space`` argument is the same as for the :func:`flatten` call.
Accepts a space and a flattened point. Returns a point with a structure that matches the space. Raises ``NotImplementedError`` if the space is not defined in ``gym.spaces``.
Args:
space: The space used to unflatten ``x``
x: The array to unflatten
Returns:
A point with a structure that matches the space.
Raises:
NotImplementedError: if the space is not defined in ``gym.spaces``.
""" """
raise NotImplementedError(f"Unknown space: `{space}`") raise NotImplementedError(f"Unknown space: `{space}`")
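A hedged round-trip sketch of :func:`flatten`/:func:`unflatten` on a ``Box``, illustrating that unflattening recovers the original point:
>>> import numpy as np
>>> from gym.spaces import Box, flatten, unflatten
>>> box = Box(0.0, 1.0, shape=(2, 3))
>>> x = box.sample()
>>> np.array_equal(unflatten(box, flatten(box, x)), x)
True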
@@ -173,9 +190,6 @@ def flatten_space(space: Space) -> Box:
:func:`flatdim` dimensions. Flattening a sample of the original space :func:`flatdim` dimensions. Flattening a sample of the original space
has the same effect as taking a sample of the flattened space. has the same effect as taking a sample of the flattened space.
Raises ``NotImplementedError`` if the space is not defined in
``gym.spaces``.
Example:: Example::
>>> box = Box(0.0, 1.0, shape=(3, 4, 5)) >>> box = Box(0.0, 1.0, shape=(3, 4, 5))
@@ -201,6 +215,15 @@ def flatten_space(space: Space) -> Box:
Box(6,) Box(6,)
>>> flatten(space, space.sample()) in flatten_space(space) >>> flatten(space, space.sample()) in flatten_space(space)
True True
Args:
space: The space to flatten
Returns:
A flattened Box
Raises:
NotImplementedError: if the space is not defined in ``gym.spaces``.
""" """
raise NotImplementedError(f"Unknown space: `{space}`") raise NotImplementedError(f"Unknown space: `{space}`")


@@ -210,6 +210,10 @@ def _check_returned_values(env: gym.Env, observation_space: Space, action_space:
env: The environment env: The environment
observation_space: The environment's observation space observation_space: The environment's observation space
action_space: The environment's action space action_space: The environment's action space
Raises:
AssertionError: If the ``observation_space`` is :class:`Dict` and
keys from :meth:`Env.reset` are not in the observation space
""" """
# because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists
obs = env.reset() obs = env.reset()
@@ -329,6 +333,10 @@ def _check_reset_seed(env: gym.Env, seed: Optional[int] = None):
Args: Args:
env: The environment to check env: The environment to check
seed: The optional seed to use seed: The optional seed to use
Raises:
AssertionError: The environment cannot be reset with a random seed,
even though `seed` or `kwargs` appear in the signature.
""" """
signature = inspect.signature(env.reset) signature = inspect.signature(env.reset)
assert ( assert (
@@ -365,6 +373,10 @@ def _check_reset_info(env: gym.Env):
Args: Args:
env: The environment to check env: The environment to check
Raises:
AssertionError: The environment cannot be reset with `return_info=True`,
even though `return_info` or `kwargs` appear in the signature.
""" """
signature = inspect.signature(env.reset) signature = inspect.signature(env.reset)
assert ( assert (
@@ -394,6 +406,10 @@ def _check_reset_options(env: gym.Env):
Args: Args:
env: The environment to check env: The environment to check
Raises:
AssertionError: The environment cannot be reset with options,
even though `options` or `kwargs` appear in the signature.
""" """
signature = inspect.signature(env.reset) signature = inspect.signature(env.reset)
assert ( assert (
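These private ``_check_reset_*`` helpers are exercised through the public checker entry point; a sketch, assuming ``check_env`` in ``gym.utils.env_checker`` and a registered ``CartPole-v1``:
>>> import gym
>>> from gym.utils.env_checker import check_env
>>> check_env(gym.make("CartPole-v1"))  # runs the reset/step signature assertions above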


@@ -62,8 +62,8 @@ class PlayableGame:
keys_to_action = self.env.unwrapped.get_keys_to_action() keys_to_action = self.env.unwrapped.get_keys_to_action()
else: else:
raise MissingKeysToAction( raise MissingKeysToAction(
"%s does not have explicit key to action mapping, " f"{self.env.spec.id} does not have explicit key to action mapping, "
"please specify one manually" % self.env.spec.id "please specify one manually"
) )
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), [])) relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
return relevant_keys return relevant_keys
@@ -81,7 +81,8 @@ class PlayableGame:
def process_event(self, event: Event): def process_event(self, event: Event):
"""Processes a PyGame event. """Processes a PyGame event.
In particular, this function is used to keep track of which buttons are currently pressed and to exit the :func:`play` function when the PyGame window is closed. In particular, this function is used to keep track of which buttons are currently pressed
and to exit the :func:`play` function when the PyGame window is closed.
Args: Args:
event: The event to process event: The event to process
@@ -258,15 +259,16 @@ class PlayPlot:
It should return a list of metrics that are computed from this data. It should return a list of metrics that are computed from this data.
For instance, the function may look like this:: For instance, the function may look like this::
def compute_metrics(obs_t, obs_tp, action, reward, done, info): >>> def compute_metrics(obs_t, obs_tp, action, reward, done, info):
return [reward, info["cumulative_reward"], np.linalg.norm(action)] ... return [reward, info["cumulative_reward"], np.linalg.norm(action)]
:class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function :class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function
and uses the returned values to update live plots of the metrics. and uses the returned values to update live plots of the metrics.
Typically, this :meth:`callback` will be used in conjunction with :func:`play` to see how the metrics evolve as you play:: Typically, this :meth:`callback` will be used in conjunction with :func:`play` to see how the metrics evolve as you play::
>>> plotter = PlayPlot(compute_metrics, horizon_timesteps=200, plot_names=["Immediate Rew.", "Cumulative Rew.", "Action Magnitude"]) >>> plotter = PlayPlot(compute_metrics, horizon_timesteps=200,
... plot_names=["Immediate Rew.", "Cumulative Rew.", "Action Magnitude"])
>>> play(your_env, callback=plotter.callback) >>> play(your_env, callback=plotter.callback)
""" """
@@ -282,6 +284,9 @@ class PlayPlot:
callback: Function that computes metrics from environment transitions callback: Function that computes metrics from environment transitions
horizon_timesteps: The time horizon used for the live plots horizon_timesteps: The time horizon used for the live plots
plot_names: List of plot titles plot_names: List of plot titles
Raises:
DependencyNotInstalled: If matplotlib is not installed
""" """
deprecation( deprecation(
"`PlayPlot` is marked as deprecated and will be removed in the near future." "`PlayPlot` is marked as deprecated and will be removed in the near future."


@@ -20,6 +20,9 @@ def np_random(seed: Optional[int] = None) -> tuple[RandomNumberGenerator, Any]:
Returns: Returns:
The generator and resulting seed The generator and resulting seed
Raises:
Error: Seed must be a non-negative integer or omitted
""" """
if seed is not None and not (isinstance(seed, int) and 0 <= seed): if seed is not None and not (isinstance(seed, int) and 0 <= seed):
raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}") raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")
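A doctest-style sketch of the seeding contract above, including the documented error path:
>>> from gym.utils import seeding
>>> rng, seed = seeding.np_random(42)
>>> seed
42
>>> seeding.np_random(-1)
Traceback (most recent call last):
    ...
gym.error.Error: Seed must be a non-negative integer or omitted, not -1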
@@ -175,6 +178,9 @@ def create_seed(a: Optional[Union[int, str]] = None, max_bytes: int = 8) -> int:
Returns: Returns:
A seed A seed
Raises:
Error: Invalid type for seed, expects None or str or int
""" """
deprecation( deprecation(
"Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. " "Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "


@@ -73,19 +73,30 @@ class AsyncVectorEnv(VectorEnv):
Args: Args:
env_fns: Functions that create the environments. env_fns: Functions that create the environments.
observation_space: Observation space of a single environment. If ``None``, then the observation space of the first environment is taken.
action_space: Action space of a single environment. If ``None``, then the action space of the first environment is taken.
shared_memory: If ``True``, then the observations from the worker processes are communicated back through shared variables. This can improve the efficiency if the observations are large (e.g. images).
copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods return a copy of the observations.
context: Context for `multiprocessing`_. If ``None``, then the default context is used.
daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if the head process quits. However, ``daemon=True`` prevents subprocesses from spawning children, so for some environments you may want to have it set to ``False``.
worker: If set, then use that worker in a subprocess instead of a default one. Can be useful to override some inner vector env logic, for instance, how resets on done are handled.
Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start from the code for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.
Raises:
RuntimeError: If the observation space of some sub-environment does not match observation_space (or, by default, the observation space of the first sub-environment).
ValueError: If observation_space is a custom space (i.e. not a default space in Gym, such as gym.spaces.Box, gym.spaces.Discrete, or gym.spaces.Dict) and shared_memory is True.
""" """
ctx = mp.get_context(context) ctx = mp.get_context(context)
self.env_fns = env_fns self.env_fns = env_fns
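A minimal construction sketch for the arguments documented above (assumes ``CartPole-v1`` is registered; the defaults use shared memory and one daemonic worker per sub-environment):
>>> import gym
>>> envs = gym.vector.AsyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(3)])
>>> envs.num_envs
3
>>> envs.close()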
@@ -163,6 +174,9 @@ class AsyncVectorEnv(VectorEnv):
Args: Args:
seed: The seeds use with the environments seed: The seeds use with the environments
Raises:
AlreadyPendingCallError: Calling `seed` while waiting for a pending call to complete
""" """
super().seed(seed=seed) super().seed(seed=seed)
self._assert_is_running() self._assert_is_running()
@@ -382,6 +396,10 @@ class AsyncVectorEnv(VectorEnv):
name: Name of the method or property to call. name: Name of the method or property to call.
*args: Arguments to apply to the method call. *args: Arguments to apply to the method call.
**kwargs: Keyword arguments to apply to the method call. **kwargs: Keyword arguments to apply to the method call.
Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
AlreadyPendingCallError: Calling `call_async` while waiting for a pending call to complete
""" """
self._assert_is_running() self._assert_is_running()
if self._state != AsyncState.DEFAULT: if self._state != AsyncState.DEFAULT:
@@ -399,10 +417,15 @@ class AsyncVectorEnv(VectorEnv):
"""Calls all parent pipes and waits for the results. """Calls all parent pipes and waits for the results.
Args: Args:
timeout: Number of seconds before the call to `call_wait` times out. If `None` (default), the call to `call_wait` never times out.
Returns: Returns:
List of the results of the individual calls to the method or property for each environment. List of the results of the individual calls to the method or property for each environment.
Raises:
NoAsyncCallError: Calling `call_wait` without any prior call to `call_async`.
TimeoutError: The call to `call_wait` has timed out after timeout second(s).
""" """
self._assert_is_running() self._assert_is_running()
if self._state != AsyncState.WAITING_CALL: if self._state != AsyncState.WAITING_CALL:
@@ -431,6 +454,10 @@ class AsyncVectorEnv(VectorEnv):
values: Values of the property to be set to. If ``values`` is a list or values: Values of the property to be set to. If ``values`` is a list or
tuple, then it corresponds to the values for each individual tuple, then it corresponds to the values for each individual
environment, otherwise a single value is set for all environments. environment, otherwise a single value is set for all environments.
Raises:
ValueError: Values must be a list or tuple with length equal to the number of environments.
AlreadyPendingCallError: Calling `set_attr` while waiting for a pending call to complete.
""" """
self._assert_is_running() self._assert_is_running()
if not isinstance(values, (list, tuple)): if not isinstance(values, (list, tuple)):
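The async call pair above is normally driven through the public :meth:`call`; a sketch, again assuming ``CartPole-v1``:
>>> import gym
>>> envs = gym.vector.AsyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
>>> metadata = envs.call("metadata")  # call_async followed by call_wait
>>> len(metadata)
2
>>> envs.close()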


@@ -39,12 +39,15 @@ class SyncVectorEnv(VectorEnv):
Args: Args:
env_fns: iterable of callable functions that create the environments. env_fns: iterable of callable functions that create the environments.
observation_space: Observation space of a single environment. If ``None``, then the observation space of the first environment is taken.
action_space: Action space of a single environment. If ``None``, then the action space of the first environment is taken.
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
Raises:
RuntimeError: If the observation space of some sub-environment does not match observation_space (or, by default, the observation space of the first sub-environment).
""" """
self.env_fns = env_fns self.env_fns = env_fns
self.envs = [env_fn() for env_fn in env_fns] self.envs = [env_fn() for env_fn in env_fns]
@@ -195,6 +198,9 @@ class SyncVectorEnv(VectorEnv):
values: Values of the property to be set to. If ``values`` is a list or values: Values of the property to be set to. If ``values`` is a list or
tuple, then it corresponds to the values for each individual tuple, then it corresponds to the values for each individual
environment, otherwise, a single value is set for all environments. environment, otherwise, a single value is set for all environments.
Raises:
ValueError: Values must be a list or tuple with length equal to the number of environments.
""" """
if not isinstance(values, (list, tuple)): if not isinstance(values, (list, tuple)):
values = [values for _ in range(self.num_envs)] values = [values for _ in range(self.num_envs)]
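A short sketch of the synchronous variant documented above (batched observations have shape ``(num_envs,) + single_observation_shape``):
>>> import gym
>>> envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
>>> obs = envs.reset(seed=42)
>>> obs.shape
(2, 4)
>>> envs.close()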


@@ -39,6 +39,9 @@ def clear_mpi_env_vars():
This context manager is a hacky way to clear those environment variables This context manager is a hacky way to clear those environment variables
temporarily such as when we are starting multiprocessing Processes. temporarily such as when we are starting multiprocessing Processes.
Yields:
Yields for the context manager
""" """
removed_environment = {} removed_environment = {}
for k, v in list(os.environ.items()): for k, v in list(os.environ.items()):
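A sketch of the context manager's effect; the ``OMPI_``-prefixed variable here is a hypothetical stand-in for what an MPI launcher would set:
>>> import os
>>> from gym.vector.utils import clear_mpi_env_vars
>>> os.environ["OMPI_COMM_WORLD_RANK"] = "0"  # hypothetical MPI-set variable
>>> with clear_mpi_env_vars():
...     "OMPI_COMM_WORLD_RANK" in os.environ  # removed inside the context
False
>>> "OMPI_COMM_WORLD_RANK" in os.environ  # restored on exit
True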


@@ -33,6 +33,9 @@ def concatenate(
Returns: Returns:
The output object. This object is a (possibly nested) numpy array. The output object. This object is a (possibly nested) numpy array.
Raises:
ValueError: Space is not a valid :class:`gym.Space` instance
""" """
raise ValueError( raise ValueError(
f"Space of type `{type(space)}` is not a valid `gym.Space` instance." f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
@@ -95,6 +98,9 @@ def create_empty_array(
Returns: Returns:
The output object. This object is a (possibly nested) numpy array. The output object. This object is a (possibly nested) numpy array.
Raises:
ValueError: Space is not a valid :class:`gym.Space` instance
""" """
raise ValueError( raise ValueError(
f"Space of type `{type(space)}` is not a valid `gym.Space` instance." f"Space of type `{type(space)}` is not a valid `gym.Space` instance."


@@ -28,6 +28,9 @@ def create_shared_memory(
Returns: Returns:
shared_memory for the shared object across processes. shared_memory for the shared object across processes.
Raises:
CustomSpaceError: Space is not a valid :class:`gym.Space` instance
""" """
raise CustomSpaceError( raise CustomSpaceError(
"Cannot create a shared memory for space with " "Cannot create a shared memory for space with "
@@ -86,6 +89,8 @@ def read_from_shared_memory(
Returns: Returns:
Batch of observations as a (possibly nested) numpy array. Batch of observations as a (possibly nested) numpy array.
Raises:
CustomSpaceError: Space is not a valid :class:`gym.Space` instance
""" """
raise CustomSpaceError( raise CustomSpaceError(
"Cannot read from a shared memory for space with " "Cannot read from a shared memory for space with "
@@ -137,7 +142,11 @@ def write_to_shared_memory(
space: Observation space of a single environment in the vectorized environment. space: Observation space of a single environment in the vectorized environment.
index: Index of the environment (must be in `[0, num_envs)`). index: Index of the environment (must be in `[0, num_envs)`).
value: Observation of the single environment to write to shared memory. value: Observation of the single environment to write to shared memory.
shared_memory: Shared object across processes. This contains the observations from the vectorized environment. This object is created with `create_shared_memory`.
Raises:
CustomSpaceError: Space is not a valid :class:`gym.Space` instance
""" """
raise CustomSpaceError( raise CustomSpaceError(
"Cannot write to a shared memory for space with " "Cannot write to a shared memory for space with "


@@ -33,6 +33,9 @@ def batch_space(space: Space, n: int = 1) -> Space:
Returns: Returns:
Space (e.g. the observation space) for a batch of environments in the vectorized environment. Space (e.g. the observation space) for a batch of environments in the vectorized environment.
Raises:
ValueError: Cannot batch space that is not a valid :class:`gym.Space` instance
""" """
raise ValueError( raise ValueError(
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gym.Space` instance." f"Cannot batch space with type `{type(space)}`. The space must be a valid `gym.Space` instance."
@@ -147,9 +150,12 @@ def iterate(space: Space, items) -> Iterator:
Returns: Returns:
Iterator over the elements in `items`. Iterator over the elements in `items`.
Raises:
ValueError: Space is not an instance of :class:`gym.Space`
""" """
raise ValueError( raise ValueError(
f"Space of type `{type(space)}` is not a valid `gym.Space` " "instance." f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
) )
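A sketch showing how :func:`batch_space` and :func:`iterate` compose for a ``Box``:
>>> from gym.spaces import Box
>>> from gym.vector.utils import batch_space, iterate
>>> space = Box(0.0, 1.0, shape=(2,))
>>> batched = batch_space(space, n=3)
>>> batched.shape
(3, 2)
>>> [item.shape for item in iterate(batched, batched.sample())]
[(2,), (2,), (2,)]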


@@ -56,7 +56,13 @@ class VectorEnv(gym.Env):
): ):
"""Reset the sub-environments asynchronously. """Reset the sub-environments asynchronously.
This method will return ``None``. A call to :meth:`reset_async` should be followed by a call to :meth:`reset_wait` to retrieve the results.
Args:
seed: The reset seed
return_info: If to return info
options: Reset options
""" """
pass pass
@@ -69,8 +75,18 @@ class VectorEnv(gym.Env):
"""Retrieves the results of a :meth:`reset_async` call. """Retrieves the results of a :meth:`reset_async` call.
A call to this method must always be preceded by a call to :meth:`reset_async`. A call to this method must always be preceded by a call to :meth:`reset_async`.
Args:
seed: The reset seed
return_info: If to return info
options: Reset options
Returns:
The results from :meth:`reset_async`
Raises:
NotImplementedError: VectorEnv does not implement function
""" """
raise NotImplementedError()
def reset( def reset(
self, self,
@@ -96,15 +112,22 @@ class VectorEnv(gym.Env):
"""Asynchronously performs steps in the sub-environments. """Asynchronously performs steps in the sub-environments.
The results can be retrieved via a call to :meth:`step_wait`. The results can be retrieved via a call to :meth:`step_wait`.
Args:
actions: The actions to take asynchronously
""" """
pass
def step_wait(self, **kwargs): def step_wait(self, **kwargs):
"""Retrieves the results of a :meth:`step_async` call. """Retrieves the results of a :meth:`step_async` call.
A call to this method must always be preceded by a call to :meth:`step_async`. A call to this method must always be preceded by a call to :meth:`step_async`.
Args:
**kwargs: Additional keywords for vector implementation
Returns:
The results from the :meth:`step_async` call
""" """
raise NotImplementedError()
def step(self, actions): def step(self, actions):
"""Take an action for each parallel environment. """Take an action for each parallel environment.
@@ -120,11 +143,9 @@ class VectorEnv(gym.Env):
def call_async(self, name, *args, **kwargs): def call_async(self, name, *args, **kwargs):
"""Calls a method name for each parallel environment asynchronously.""" """Calls a method name for each parallel environment asynchronously."""
pass
def call_wait(self, **kwargs): def call_wait(self, **kwargs) -> list[Any]:
"""After calling a method in :meth:`call_async`, this function collects the results.""" """After calling a method in :meth:`call_async`, this function collects the results."""
raise NotImplementedError()
def call(self, name: str, *args, **kwargs) -> list[Any]: def call(self, name: str, *args, **kwargs) -> list[Any]:
"""Call a method, or get a property, from each parallel environment. """Call a method, or get a property, from each parallel environment.
@@ -160,7 +181,6 @@ class VectorEnv(gym.Env):
tuple, then it corresponds to the values for each individual environment, otherwise a single value tuple, then it corresponds to the values for each individual environment, otherwise a single value
is set for all environments. is set for all environments.
""" """
raise NotImplementedError()
def close_extras(self, **kwargs): def close_extras(self, **kwargs):
"""Clean up the extra resources e.g. beyond what's in this base class.""" """Clean up the extra resources e.g. beyond what's in this base class."""
@@ -180,6 +200,8 @@ class VectorEnv(gym.Env):
Notes: Notes:
This will be automatically called when garbage collected or program exited. This will be automatically called when garbage collected or program exited.
Args:
**kwargs: Keyword arguments passed to :meth:`close_extras`
""" """
if self.closed: if self.closed:
return return
@@ -260,8 +282,12 @@ class VectorEnv(gym.Env):
if not getattr(self, "closed", True): if not getattr(self, "closed", True):
self.close() self.close()
def __repr__(self): def __repr__(self) -> str:
"""Returns a string representation of the vector environment using the class name, number of environments and environment spec id.""" """Returns a string representation of the vector environment.
Returns:
A string containing the class name, number of environments and environment spec id
"""
if self.spec is None: if self.spec is None:
return f"{self.__class__.__name__}({self.num_envs})" return f"{self.__class__.__name__}({self.num_envs})"
else: else:
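The ``*_async``/``*_wait`` pairs above are usually driven through the public :meth:`reset` and :meth:`step`; a sketch using ``gym.vector.make`` (assumes ``CartPole-v1``):
>>> import gym
>>> import numpy as np
>>> envs = gym.vector.make("CartPole-v1", num_envs=2)
>>> obs = envs.reset(seed=0)
>>> obs, rewards, dones, infos = envs.step(np.array([0, 1]))
>>> rewards.shape
(2,)
>>> envs.close()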


@@ -54,6 +54,10 @@ class AtariPreprocessing(gym.Wrapper):
grayscale observations to make them 3-dimensional. grayscale observations to make them 3-dimensional.
scale_obs (bool): if True, then observation normalized in range [0,1) is returned. It also limits memory scale_obs (bool): if True, then observation normalized in range [0,1) is returned. It also limits memory
optimization benefits of FrameStack Wrapper. optimization benefits of FrameStack Wrapper.
Raises:
DependencyNotInstalled: opencv-python package not installed
ValueError: Disable frame-skipping in the original env
""" """
super().__init__(env) super().__init__(env)
if cv2 is None: if cv2 is None:
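A usage sketch for the wrapper (assumes the Atari extra, i.e. ale-py plus ROMs, is installed; a ``NoFrameskip`` variant avoids the frame-skipping ValueError above):
>>> import gym
>>> env = gym.make("PongNoFrameskip-v4")
>>> env = gym.wrappers.AtariPreprocessing(env)  # defaults: 84x84 grayscale, frame_skip=4
>>> env.observation_space.shape
(84, 84)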


@@ -26,6 +26,9 @@ class LazyFrames:
Args: Args:
frames (list): The frames to convert to lazy frames frames (list): The frames to convert to lazy frames
lz4_compress (bool): Use lz4 to compress the frames internally lz4_compress (bool): Use lz4 to compress the frames internally
Raises:
DependencyNotInstalled: lz4 is not installed
""" """
self.frame_shape = tuple(frames[0].shape) self.frame_shape = tuple(frames[0].shape)
self.shape = (len(frames),) + self.frame_shape self.shape = (len(frames),) + self.frame_shape
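``LazyFrames`` are normally produced by the ``FrameStack`` wrapper; a sketch showing the ``(len(frames),) + frame_shape`` shape computed above:
>>> import gym
>>> from gym.wrappers import FrameStack
>>> env = FrameStack(gym.make("CartPole-v1"), num_stack=4)
>>> obs = env.reset()  # a LazyFrames object
>>> obs.shape
(4, 4)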


@@ -46,6 +46,10 @@ class VideoRecorder:
metadata (Optional[dict]): Contents to save to the metadata file. metadata (Optional[dict]): Contents to save to the metadata file.
enabled (bool): Whether to actually record video, or just no-op (for convenience) enabled (bool): Whether to actually record video, or just no-op (for convenience)
base_path (Optional[str]): Alternatively, path to the video file without extension, which will be added. base_path (Optional[str]): Alternatively, path to the video file without extension, which will be added.
Raises:
Error: You can pass at most one of `path` or `base_path`
Error: Invalid path given; it must have a particular file extension
""" """
modes = env.metadata.get("render_modes", []) modes = env.metadata.get("render_modes", [])
@@ -281,6 +285,9 @@ class TextEncoder:
Args: Args:
frame: A string or StringIO frame frame: A string or StringIO frame
Raises:
InvalidFrame: Wrong type for a frame, expects text frame to be a string or StringIO
""" """
if isinstance(frame, str): if isinstance(frame, str):
string = frame string = frame
@@ -366,6 +373,10 @@ class ImageEncoder:
frame_shape: The expected frame shape, a tuple of height, width and channels (3 or 4) frame_shape: The expected frame shape, a tuple of height, width and channels (3 or 4)
frames_per_sec: The number of frames per second the environment runs at frames_per_sec: The number of frames per second the environment runs at
output_frames_per_sec: The output number of frames per second for the video output_frames_per_sec: The output number of frames per second for the video
Raises:
InvalidFrame: Expects frame to have shape (w,h,3) or (w,h,4)
DependencyNotInstalled: Found neither the ffmpeg nor avconv executables.
""" """
self.proc = None self.proc = None
self.output_path = output_path self.output_path = output_path


@@ -77,6 +77,7 @@ class PixelObservationWrapper(gym.ObservationWrapper):
arrays. arrays.
ValueError: If ``env``'s observation already contains any of the ValueError: If ``env``'s observation already contains any of the
specified ``pixel_keys``. specified ``pixel_keys``.
TypeError: When an unexpected pixel type is used
""" """
super().__init__(env) super().__init__(env)


@@ -11,6 +11,12 @@ def capped_cubic_video_schedule(episode_id: int) -> bool:
"""The default episode trigger. """The default episode trigger.
This function will trigger recordings at the episode indices 0, 1, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ... This function will trigger recordings at the episode indices 0, 1, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ...
Args:
episode_id: The episode number
Returns:
Whether a video should be recorded at this episode number
""" """
if episode_id < 1000: if episode_id < 1000:
return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
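A sketch of the trigger pattern (assuming the function lives in ``gym.wrappers.record_video``): perfect cubes below episode 1000, then every 1000th episode.
>>> from gym.wrappers.record_video import capped_cubic_video_schedule
>>> [n for n in range(30) if capped_cubic_video_schedule(n)]
[0, 1, 8, 27]
>>> capped_cubic_video_schedule(2000)
True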


@@ -50,6 +50,9 @@ class ResizeObservation(gym.ObservationWrapper):
Returns: Returns:
The reshaped observations The reshaped observations
Raises:
DependencyNotInstalled: opencv-python is not installed
""" """
try: try:
import cv2 import cv2