mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-09-02 18:36:16 +00:00
Updated docstrings using darglint (#2827)
* Updated docstrings using darglint, ignoring 402 and 202 plus shortened lines into multiple where they were overflowing * Remove abstract method decorators, for a future PR * Add __future__ import annotation for python 3.7+ notion * Added missing bracket * Fix minor docstring tables
This commit is contained in:
15
gym/core.py
15
gym/core.py
@@ -1,7 +1,6 @@
|
|||||||
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
|
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import abstractmethod
|
|
||||||
from typing import Generic, Optional, SupportsFloat, TypeVar, Union
|
from typing import Generic, Optional, SupportsFloat, TypeVar, Union
|
||||||
|
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
@@ -63,7 +62,6 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
def np_random(self, value: RandomNumberGenerator):
|
def np_random(self, value: RandomNumberGenerator):
|
||||||
self._np_random = value
|
self._np_random = value
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
|
def step(self, action: ActType) -> tuple[ObsType, float, bool, dict]:
|
||||||
"""Run one timestep of the environment's dynamics.
|
"""Run one timestep of the environment's dynamics.
|
||||||
|
|
||||||
@@ -71,7 +69,7 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
Accepts an action and returns a tuple `(observation, reward, done, info)`.
|
Accepts an action and returns a tuple `(observation, reward, done, info)`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
action (object): an action provided by the agent
|
action (ActType): an action provided by the agent
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
observation (object): this will be an element of the environment's :attr:`observation_space`.
|
observation (object): this will be an element of the environment's :attr:`observation_space`.
|
||||||
@@ -88,7 +86,6 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def reset(
|
def reset(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -129,7 +126,6 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
if seed is not None:
|
if seed is not None:
|
||||||
self._np_random, seed = seeding.np_random(seed)
|
self._np_random, seed = seeding.np_random(seed)
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def render(self, mode="human"):
|
def render(self, mode="human"):
|
||||||
"""Renders the environment.
|
"""Renders the environment.
|
||||||
|
|
||||||
@@ -152,6 +148,7 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
in implementations to use the functionality of this method.
|
in implementations to use the functionality of this method.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
>>> import numpy as np
|
||||||
>>> class MyEnv(Env):
|
>>> class MyEnv(Env):
|
||||||
... metadata = {'render_modes': ['human', 'rgb_array']}
|
... metadata = {'render_modes': ['human', 'rgb_array']}
|
||||||
...
|
...
|
||||||
@@ -161,7 +158,7 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
... elif mode == 'human':
|
... elif mode == 'human':
|
||||||
... ... # pop up a window and render
|
... ... # pop up a window and render
|
||||||
... else:
|
... else:
|
||||||
... super(MyEnv, self).render(mode=mode) # just raise an exception
|
... super().render(mode=mode) # just raise an exception
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mode: the mode to render with, valid modes are `env.metadata["render_modes"]`
|
mode: the mode to render with, valid modes are `env.metadata["render_modes"]`
|
||||||
@@ -208,7 +205,7 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
"""Returns the base non-wrapped environment.
|
"""Returns the base non-wrapped environment.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
gym.Env: The base non-wrapped gym.Env instance
|
Env: The base non-wrapped gym.Env instance
|
||||||
"""
|
"""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -389,7 +386,6 @@ class ObservationWrapper(Wrapper):
|
|||||||
observation, reward, done, info = self.env.step(action)
|
observation, reward, done, info = self.env.step(action)
|
||||||
return self.observation(observation), reward, done, info
|
return self.observation(observation), reward, done, info
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
"""Returns a modified observation."""
|
"""Returns a modified observation."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -424,7 +420,6 @@ class RewardWrapper(Wrapper):
|
|||||||
observation, reward, done, info = self.env.step(action)
|
observation, reward, done, info = self.env.step(action)
|
||||||
return observation, self.reward(reward), done, info
|
return observation, self.reward(reward), done, info
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def reward(self, reward):
|
def reward(self, reward):
|
||||||
"""Returns a modified ``reward``."""
|
"""Returns a modified ``reward``."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -466,12 +461,10 @@ class ActionWrapper(Wrapper):
|
|||||||
"""Runs the environment :meth:`env.step` using the modified ``action`` from :meth:`self.action`."""
|
"""Runs the environment :meth:`env.step` using the modified ``action`` from :meth:`self.action`."""
|
||||||
return self.env.step(self.action(action))
|
return self.env.step(self.action(action))
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def action(self, action):
|
def action(self, action):
|
||||||
"""Returns a modified action before :meth:`env.step` is called."""
|
"""Returns a modified action before :meth:`env.step` is called."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def reverse_action(self, action):
|
def reverse_action(self, action):
|
||||||
"""Returns a reversed ``action``."""
|
"""Returns a reversed ``action``."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@@ -704,7 +704,8 @@ def heuristic(env, s):
|
|||||||
s[5] is the angular speed
|
s[5] is the angular speed
|
||||||
s[6] 1 if first leg has contact, else 0
|
s[6] 1 if first leg has contact, else 0
|
||||||
s[7] 1 if second leg has contact, else 0
|
s[7] 1 if second leg has contact, else 0
|
||||||
returns:
|
|
||||||
|
Returns:
|
||||||
a: The heuristic to be fed into the step function defined above to determine the next step and reward.
|
a: The heuristic to be fed into the step function defined above to determine the next step and reward.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@@ -45,7 +45,7 @@ class AcrobotEnv(core.Env):
|
|||||||
joint between the two links.
|
joint between the two links.
|
||||||
|
|
||||||
| Num | Action | Unit |
|
| Num | Action | Unit |
|
||||||
|----|-------------------------------------------|---------------|
|
|-----|---------------------------------------|--------------|
|
||||||
| 0 | apply -1 torque to the actuated joint | torque (N m) |
|
| 0 | apply -1 torque to the actuated joint | torque (N m) |
|
||||||
| 1 | apply 0 torque to the actuated joint | torque (N m) |
|
| 1 | apply 0 torque to the actuated joint | torque (N m) |
|
||||||
| 2 | apply 1 torque to the actuated joint | torque (N m) |
|
| 2 | apply 1 torque to the actuated joint | torque (N m) |
|
||||||
@@ -56,7 +56,7 @@ class AcrobotEnv(core.Env):
|
|||||||
two rotational joint angles as well as their angular velocities:
|
two rotational joint angles as well as their angular velocities:
|
||||||
|
|
||||||
| Num | Observation | Min | Max |
|
| Num | Observation | Min | Max |
|
||||||
|-----|-----------------------|----------------------|--------------------|
|
|-----|------------------------------|---------------------|-------------------|
|
||||||
| 0 | Cosine of `theta1` | -1 | 1 |
|
| 0 | Cosine of `theta1` | -1 | 1 |
|
||||||
| 1 | Sine of `theta1` | -1 | 1 |
|
| 1 | Sine of `theta1` | -1 | 1 |
|
||||||
| 2 | Cosine of `theta2` | -1 | 1 |
|
| 2 | Cosine of `theta2` | -1 | 1 |
|
||||||
@@ -67,15 +67,17 @@ class AcrobotEnv(core.Env):
|
|||||||
where
|
where
|
||||||
- `theta1` is the angle of the first joint, where an angle of 0 indicates the first link is pointing directly
|
- `theta1` is the angle of the first joint, where an angle of 0 indicates the first link is pointing directly
|
||||||
downwards.
|
downwards.
|
||||||
- `theta2` is ***relative to the angle of the first link.*** An angle of 0 corresponds to having the same angle between the
|
- `theta2` is ***relative to the angle of the first link.***
|
||||||
two links.
|
An angle of 0 corresponds to having the same angle between the two links.
|
||||||
|
|
||||||
The angular velocities of `theta1` and `theta2` are bounded at ±4π, and ±9π rad/s respectively.
|
The angular velocities of `theta1` and `theta2` are bounded at ±4π, and ±9π rad/s respectively.
|
||||||
A state of `[1, 0, 1, 0, ..., ...]` indicates that both links are pointing downwards.
|
A state of `[1, 0, 1, 0, ..., ...]` indicates that both links are pointing downwards.
|
||||||
|
|
||||||
### Rewards
|
### Rewards
|
||||||
|
|
||||||
The goal is to have the free end reach a designated target height in as few steps as possible, and as such all steps that do not reach the goal incur a reward of -1. Achieving the target height results in termination with a reward of 0. The reward threshold is -100.
|
The goal is to have the free end reach a designated target height in as few steps as possible,
|
||||||
|
and as such all steps that do not reach the goal incur a reward of -1.
|
||||||
|
Achieving the target height results in termination with a reward of 0. The reward threshold is -100.
|
||||||
|
|
||||||
### Starting State
|
### Starting State
|
||||||
|
|
||||||
@@ -98,7 +100,8 @@ class AcrobotEnv(core.Env):
|
|||||||
```
|
```
|
||||||
|
|
||||||
By default, the dynamics of the acrobot follow those described in Sutton and Barto's book
|
By default, the dynamics of the acrobot follow those described in Sutton and Barto's book
|
||||||
[Reinforcement Learning: An Introduction](http://incompleteideas.net/book/11/node4.html). However, a `book_or_nips` parameter can be modified to change the pendulum dynamics to those described
|
[Reinforcement Learning: An Introduction](http://incompleteideas.net/book/11/node4.html).
|
||||||
|
However, a `book_or_nips` parameter can be modified to change the pendulum dynamics to those described
|
||||||
in the original [NeurIPS paper](https://papers.nips.cc/paper/1995/hash/8f1d43620bc6bb580df6e80b0dc05c48-Abstract.html).
|
in the original [NeurIPS paper](https://papers.nips.cc/paper/1995/hash/8f1d43620bc6bb580df6e80b0dc05c48-Abstract.html).
|
||||||
|
|
||||||
```
|
```
|
||||||
@@ -125,7 +128,9 @@ class AcrobotEnv(core.Env):
|
|||||||
- v0: Initial versions release (1.0.0) (removed from gym for v1)
|
- v0: Initial versions release (1.0.0) (removed from gym for v1)
|
||||||
|
|
||||||
### References
|
### References
|
||||||
- Sutton, R. S. (1996). Generalization in Reinforcement Learning: Successful Examples Using Sparse Coarse Coding. In D. Touretzky, M. C. Mozer, & M. Hasselmo (Eds.), Advances in Neural Information Processing Systems (Vol. 8). MIT Press. https://proceedings.neurips.cc/paper/1995/file/8f1d43620bc6bb580df6e80b0dc05c48-Paper.pdf
|
- Sutton, R. S. (1996). Generalization in Reinforcement Learning: Successful Examples Using Sparse Coarse Coding.
|
||||||
|
In D. Touretzky, M. C. Mozer, & M. Hasselmo (Eds.), Advances in Neural Information Processing Systems (Vol. 8).
|
||||||
|
MIT Press. https://proceedings.neurips.cc/paper/1995/file/8f1d43620bc6bb580df6e80b0dc05c48-Paper.pdf
|
||||||
- Sutton, R. S., Barto, A. G. (2018 ). Reinforcement Learning: An Introduction. The MIT Press.
|
- Sutton, R. S., Barto, A. G. (2018 ). Reinforcement Learning: An Introduction. The MIT Press.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -380,6 +385,8 @@ def bound(x, m, M=None):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: scalar
|
x: scalar
|
||||||
|
m: The lower bound
|
||||||
|
M: The upper bound
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
x: scalar, bound between min (m) and Max (M)
|
x: scalar, bound between min (m) and Max (M)
|
||||||
@@ -398,15 +405,15 @@ def rk4(derivs, y0, t):
|
|||||||
yourself stranded on a system w/o scipy. Otherwise use
|
yourself stranded on a system w/o scipy. Otherwise use
|
||||||
:func:`scipy.integrate`.
|
:func:`scipy.integrate`.
|
||||||
|
|
||||||
Example:
|
Example for 2D system:
|
||||||
|
|
||||||
>>> ### 2D system
|
|
||||||
>>> def derivs(x):
|
>>> def derivs(x):
|
||||||
... d1 = x[0] + 2*x[1]
|
... d1 = x[0] + 2*x[1]
|
||||||
... d2 = -3*x[0] + 4*x[1]
|
... d2 = -3*x[0] + 4*x[1]
|
||||||
... return (d1, d2)
|
... return d1, d2
|
||||||
|
|
||||||
>>> dt = 0.0005
|
>>> dt = 0.0005
|
||||||
>>> t = arange(0.0, 2.0, dt)
|
>>> t = np.arange(0.0, 2.0, dt)
|
||||||
>>> y0 = (1,2)
|
>>> y0 = (1,2)
|
||||||
>>> yout = rk4(derivs, y0, t)
|
>>> yout = rk4(derivs, y0, t)
|
||||||
|
|
||||||
|
@@ -17,40 +17,47 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
|||||||
"""
|
"""
|
||||||
### Description
|
### Description
|
||||||
|
|
||||||
This environment corresponds to the version of the cart-pole problem
|
This environment corresponds to the version of the cart-pole problem described by Barto, Sutton, and Anderson in
|
||||||
described by Barto, Sutton, and Anderson in ["Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problem"](https://ieeexplore.ieee.org/document/6313077).
|
["Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problem"](https://ieeexplore.ieee.org/document/6313077).
|
||||||
A pole is attached by an un-actuated joint to a cart, which moves along a
|
A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track.
|
||||||
frictionless track. The pendulum is placed upright on the cart and the goal is to balance the pole by applying forces in the left and right direction on the cart.
|
The pendulum is placed upright on the cart and the goal is to balance the pole by applying forces
|
||||||
|
in the left and right direction on the cart.
|
||||||
|
|
||||||
### Action Space
|
### Action Space
|
||||||
|
|
||||||
The action is a `ndarray` with shape `(1,)` which can take values `{0, 1}` indicating the direction of the fixed force the cart is pushed with.
|
The action is a `ndarray` with shape `(1,)` which can take values `{0, 1}` indicating the direction
|
||||||
|
of the fixed force the cart is pushed with.
|
||||||
|
|
||||||
| Num | Action |
|
| Num | Action |
|
||||||
|-----|------------------------|
|
|-----|------------------------|
|
||||||
| 0 | Push cart to the left |
|
| 0 | Push cart to the left |
|
||||||
| 1 | Push cart to the right |
|
| 1 | Push cart to the right |
|
||||||
|
|
||||||
**Note**: The velocity that is reduced or increased by the applied force is not fixed and it depends on the angle the pole is pointing. The center of gravity of the pole varies the amount of energy needed to move the cart underneath it
|
**Note**: The velocity that is reduced or increased by the applied force is not fixed and it depends on the angle
|
||||||
|
the pole is pointing. The center of gravity of the pole varies the amount of energy needed to move the cart underneath it
|
||||||
|
|
||||||
### Observation Space
|
### Observation Space
|
||||||
|
|
||||||
The observation is a `ndarray` with shape `(4,)` with the values corresponding to the following positions and velocities:
|
The observation is a `ndarray` with shape `(4,)` with the values corresponding to the following positions and velocities:
|
||||||
|
|
||||||
| Num | Observation | Min | Max |
|
| Num | Observation | Min | Max |
|
||||||
|-----|-----------------------|----------------------|--------------------|
|
|-----|-----------------------|---------------------|-------------------|
|
||||||
| 0 | Cart Position | -4.8 | 4.8 |
|
| 0 | Cart Position | -4.8 | 4.8 |
|
||||||
| 1 | Cart Velocity | -Inf | Inf |
|
| 1 | Cart Velocity | -Inf | Inf |
|
||||||
| 2 | Pole Angle | ~ -0.418 rad (-24°) | ~ 0.418 rad (24°) |
|
| 2 | Pole Angle | ~ -0.418 rad (-24°) | ~ 0.418 rad (24°) |
|
||||||
| 3 | Pole Angular Velocity | -Inf | Inf |
|
| 3 | Pole Angular Velocity | -Inf | Inf |
|
||||||
|
|
||||||
**Note:** While the ranges above denote the possible values for observation space of each element, it is not reflective of the allowed values of the state space in an unterminated episode. Particularly:
|
**Note:** While the ranges above denote the possible values for observation space of each element,
|
||||||
- The cart x-position (index 0) can be take values between `(-4.8, 4.8)`, but the episode terminates if the cart leaves the `(-2.4, 2.4)` range.
|
it is not reflective of the allowed values of the state space in an unterminated episode. Particularly:
|
||||||
- The pole angle can be observed between `(-.418, .418)` radians (or **±24°**), but the episode terminates if the pole angle is not in the range `(-.2095, .2095)` (or **±12°**)
|
- The cart x-position (index 0) can be take values between `(-4.8, 4.8)`, but the episode terminates
|
||||||
|
if the cart leaves the `(-2.4, 2.4)` range.
|
||||||
|
- The pole angle can be observed between `(-.418, .418)` radians (or **±24°**), but the episode terminates
|
||||||
|
if the pole angle is not in the range `(-.2095, .2095)` (or **±12°**)
|
||||||
|
|
||||||
### Rewards
|
### Rewards
|
||||||
|
|
||||||
Since the goal is to keep the pole upright for as long as possible, a reward of `+1` for every step taken, including the termination step, is allotted. The threshold for rewards is 475 for v1.
|
Since the goal is to keep the pole upright for as long as possible, a reward of `+1` for every step taken,
|
||||||
|
including the termination step, is allotted. The threshold for rewards is 475 for v1.
|
||||||
|
|
||||||
### Starting State
|
### Starting State
|
||||||
|
|
||||||
|
@@ -50,13 +50,14 @@ class Continuous_MountainCarEnv(gym.Env):
|
|||||||
The observation is a `ndarray` with shape `(2,)` where the elements correspond to the following:
|
The observation is a `ndarray` with shape `(2,)` where the elements correspond to the following:
|
||||||
|
|
||||||
| Num | Observation | Min | Max | Unit |
|
| Num | Observation | Min | Max | Unit |
|
||||||
|-----|-------------------------------------------------------------|--------------------|--------|------|
|
|-----|--------------------------------------|------|-----|--------------|
|
||||||
| 0 | position of the car along the x-axis | -Inf | Inf | position (m) |
|
| 0 | position of the car along the x-axis | -Inf | Inf | position (m) |
|
||||||
| 1 | velocity of the car | -Inf | Inf | position (m) |
|
| 1 | velocity of the car | -Inf | Inf | position (m) |
|
||||||
|
|
||||||
### Action Space
|
### Action Space
|
||||||
|
|
||||||
The action is a `ndarray` with shape `(1,)`, representing the directional force applied on the car. The action is clipped in the range `[-1,1]` and multiplied by a power of 0.0015.
|
The action is a `ndarray` with shape `(1,)`, representing the directional force applied on the car.
|
||||||
|
The action is clipped in the range `[-1,1]` and multiplied by a power of 0.0015.
|
||||||
|
|
||||||
### Transition Dynamics:
|
### Transition Dynamics:
|
||||||
|
|
||||||
@@ -66,15 +67,20 @@ class Continuous_MountainCarEnv(gym.Env):
|
|||||||
|
|
||||||
*position<sub>t+1</sub> = position<sub>t</sub> + velocity<sub>t+1</sub>*
|
*position<sub>t+1</sub> = position<sub>t</sub> + velocity<sub>t+1</sub>*
|
||||||
|
|
||||||
where force is the action clipped to the range `[-1,1]` and power is a constant 0.0015. The collisions at either end are inelastic with the velocity set to 0 upon collision with the wall. The position is clipped to the range [-1.2, 0.6] and velocity is clipped to the range [-0.07, 0.07].
|
where force is the action clipped to the range `[-1,1]` and power is a constant 0.0015.
|
||||||
|
The collisions at either end are inelastic with the velocity set to 0 upon collision with the wall.
|
||||||
|
The position is clipped to the range [-1.2, 0.6] and velocity is clipped to the range [-0.07, 0.07].
|
||||||
|
|
||||||
### Reward
|
### Reward
|
||||||
|
|
||||||
A negative reward of *-0.1 * action<sup>2</sup>* is received at each timestep to penalise for taking actions of large magnitude. If the mountain car reaches the goal then a positive reward of +100 is added to the negative reward for that timestep.
|
A negative reward of *-0.1 * action<sup>2</sup>* is received at each timestep to penalise for
|
||||||
|
taking actions of large magnitude. If the mountain car reaches the goal then a positive reward of +100
|
||||||
|
is added to the negative reward for that timestep.
|
||||||
|
|
||||||
### Starting State
|
### Starting State
|
||||||
|
|
||||||
The position of the car is assigned a uniform random value in `[-0.6 , -0.4]`. The starting velocity of the car is always assigned to 0.
|
The position of the car is assigned a uniform random value in `[-0.6 , -0.4]`.
|
||||||
|
The starting velocity of the car is always assigned to 0.
|
||||||
|
|
||||||
### Episode Termination
|
### Episode Termination
|
||||||
|
|
||||||
|
@@ -39,7 +39,7 @@ class MountainCarEnv(gym.Env):
|
|||||||
The observation is a `ndarray` with shape `(2,)` where the elements correspond to the following:
|
The observation is a `ndarray` with shape `(2,)` where the elements correspond to the following:
|
||||||
|
|
||||||
| Num | Observation | Min | Max | Unit |
|
| Num | Observation | Min | Max | Unit |
|
||||||
|-----|-------------------------------------------------------------|--------------------|--------|------|
|
|-----|--------------------------------------|------|-----|--------------|
|
||||||
| 0 | position of the car along the x-axis | -Inf | Inf | position (m) |
|
| 0 | position of the car along the x-axis | -Inf | Inf | position (m) |
|
||||||
| 1 | velocity of the car | -Inf | Inf | position (m) |
|
| 1 | velocity of the car | -Inf | Inf | position (m) |
|
||||||
|
|
||||||
@@ -48,7 +48,7 @@ class MountainCarEnv(gym.Env):
|
|||||||
There are 3 discrete deterministic actions:
|
There are 3 discrete deterministic actions:
|
||||||
|
|
||||||
| Num | Observation | Value | Unit |
|
| Num | Observation | Value | Unit |
|
||||||
|-----|-------------------------------------------------------------|---------|------|
|
|-----|-------------------------|------ |--------------|
|
||||||
| 0 | Accelerate to the left | Inf | position (m) |
|
| 0 | Accelerate to the left | Inf | position (m) |
|
||||||
| 1 | Don't accelerate | Inf | position (m) |
|
| 1 | Don't accelerate | Inf | position (m) |
|
||||||
| 2 | Accelerate to the right | Inf | position (m) |
|
| 2 | Accelerate to the right | Inf | position (m) |
|
||||||
@@ -61,16 +61,21 @@ class MountainCarEnv(gym.Env):
|
|||||||
|
|
||||||
*position<sub>t+1</sub> = position<sub>t</sub> + velocity<sub>t+1</sub>*
|
*position<sub>t+1</sub> = position<sub>t</sub> + velocity<sub>t+1</sub>*
|
||||||
|
|
||||||
where force = 0.001 and gravity = 0.0025. The collisions at either end are inelastic with the velocity set to 0 upon collision with the wall. The position is clipped to the range `[-1.2, 0.6]` and velocity is clipped to the range `[-0.07, 0.07]`.
|
where force = 0.001 and gravity = 0.0025. The collisions at either end are inelastic with the velocity set to 0
|
||||||
|
upon collision with the wall. The position is clipped to the range `[-1.2, 0.6]` and
|
||||||
|
velocity is clipped to the range `[-0.07, 0.07]`.
|
||||||
|
|
||||||
|
|
||||||
### Reward:
|
### Reward:
|
||||||
|
|
||||||
The goal is to reach the flag placed on top of the right hill as quickly as possible, as such the agent is penalised with a reward of -1 for each timestep it isn't at the goal and is not penalised (reward = 0) for when it reaches the goal.
|
The goal is to reach the flag placed on top of the right hill as quickly as possible, as such the agent is
|
||||||
|
penalised with a reward of -1 for each timestep it isn't at the goal and is not penalised (reward = 0) for
|
||||||
|
when it reaches the goal.
|
||||||
|
|
||||||
### Starting State
|
### Starting State
|
||||||
|
|
||||||
The position of the car is assigned a uniform random value in *[-0.6 , -0.4]*. The starting velocity of the car is always assigned to 0.
|
The position of the car is assigned a uniform random value in *[-0.6 , -0.4]*.
|
||||||
|
The starting velocity of the car is always assigned to 0.
|
||||||
|
|
||||||
### Episode Termination
|
### Episode Termination
|
||||||
|
|
||||||
|
@@ -14,7 +14,10 @@ class PendulumEnv(gym.Env):
|
|||||||
"""
|
"""
|
||||||
### Description
|
### Description
|
||||||
|
|
||||||
The inverted pendulum swingup problem is based on the classic problem in control theory. The system consists of a pendulum attached at one end to a fixed point, and the other end being free. The pendulum starts in a random position and the goal is to apply torque on the free end to swing it into an upright position, with its center of gravity right above the fixed point.
|
The inverted pendulum swingup problem is based on the classic problem in control theory.
|
||||||
|
The system consists of a pendulum attached at one end to a fixed point, and the other end being free.
|
||||||
|
The pendulum starts in a random position and the goal is to apply torque on the free end to swing it
|
||||||
|
into an upright position, with its center of gravity right above the fixed point.
|
||||||
|
|
||||||
The diagram below specifies the coordinate system used for the implementation of the pendulum's
|
The diagram below specifies the coordinate system used for the implementation of the pendulum's
|
||||||
dynamic equations.
|
dynamic equations.
|
||||||
@@ -36,7 +39,8 @@ class PendulumEnv(gym.Env):
|
|||||||
|
|
||||||
### Observation Space
|
### Observation Space
|
||||||
|
|
||||||
The observation is a `ndarray` with shape `(3,)` representing the x-y coordinates of the pendulum's free end and its angular velocity.
|
The observation is a `ndarray` with shape `(3,)` representing the x-y coordinates of the pendulum's free
|
||||||
|
end and its angular velocity.
|
||||||
|
|
||||||
| Num | Observation | Min | Max |
|
| Num | Observation | Min | Max |
|
||||||
|-----|------------------|------|-----|
|
|-----|------------------|------|-----|
|
||||||
@@ -51,8 +55,9 @@ class PendulumEnv(gym.Env):
|
|||||||
*r = -(theta<sup>2</sup> + 0.1 * theta_dt<sup>2</sup> + 0.001 * torque<sup>2</sup>)*
|
*r = -(theta<sup>2</sup> + 0.1 * theta_dt<sup>2</sup> + 0.001 * torque<sup>2</sup>)*
|
||||||
|
|
||||||
where `$\theta$` is the pendulum's angle normalized between *[-pi, pi]* (with 0 being in the upright position).
|
where `$\theta$` is the pendulum's angle normalized between *[-pi, pi]* (with 0 being in the upright position).
|
||||||
Based on the above equation, the minimum reward that can be obtained is *-(pi<sup>2</sup> + 0.1 * 8<sup>2</sup> + 0.001 * 2<sup>2</sup>) = -16.2736044*, while the maximum reward is zero (pendulum is
|
Based on the above equation, the minimum reward that can be obtained is
|
||||||
upright with zero velocity and no torque applied).
|
*-(pi<sup>2</sup> + 0.1 * 8<sup>2</sup> + 0.001 * 2<sup>2</sup>) = -16.2736044*,
|
||||||
|
while the maximum reward is zero (pendulum is upright with zero velocity and no torque applied).
|
||||||
|
|
||||||
### Starting State
|
### Starting State
|
||||||
|
|
||||||
@@ -64,7 +69,8 @@ class PendulumEnv(gym.Env):
|
|||||||
|
|
||||||
### Arguments
|
### Arguments
|
||||||
|
|
||||||
- `g`: acceleration of gravity measured in *(m s<sup>-2</sup>)* used to calculate the pendulum dynamics. The default value is g = 10.0 .
|
- `g`: acceleration of gravity measured in *(m s<sup>-2</sup>)* used to calculate the pendulum dynamics.
|
||||||
|
The default value is g = 10.0 .
|
||||||
|
|
||||||
```
|
```
|
||||||
gym.make('Pendulum-v1', g=9.81)
|
gym.make('Pendulum-v1', g=9.81)
|
||||||
|
@@ -122,10 +122,8 @@ class MujocoEnv(gym.Env):
|
|||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
"""
|
"""
|
||||||
This method is called when the viewer is initialized.
|
This method is called when the viewer is initialized.
|
||||||
Optionally implement this method, if you need to tinker with camera position
|
Optionally implement this method, if you need to tinker with camera position and so forth.
|
||||||
and so forth.
|
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
|
|
||||||
|
@@ -15,7 +15,7 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
The action space is a `Box(-2, 2, (7,), float32)`. An action `(a, b)` represents the torques applied at the hinge joints.
|
The action space is a `Box(-2, 2, (7,), float32)`. An action `(a, b)` represents the torques applied at the hinge joints.
|
||||||
|
|
||||||
| Num | Action | Control Min | Control Max | Name (in corresponding XML file) | Joint | Unit |
|
| Num | Action | Control Min | Control Max | Name (in corresponding XML file) | Joint | Unit |
|
||||||
|-----|--------------------------------------------------------------------|-------------|-------------|----------------------------------|-------|--------------|
|
|-----|------------------------------------------------|-------------|-------------|----------------------------------|-------|--------------|
|
||||||
| 0 | Rotation of the panning the shoulder | -2 | 2 | r_shoulder_pan_joint | hinge | torque (N m) |
|
| 0 | Rotation of the panning the shoulder | -2 | 2 | r_shoulder_pan_joint | hinge | torque (N m) |
|
||||||
| 1 | Rotation of the shoulder lifting joint | -2 | 2 | r_shoulder_lift_joint | hinge | torque (N m) |
|
| 1 | Rotation of the shoulder lifting joint | -2 | 2 | r_shoulder_lift_joint | hinge | torque (N m) |
|
||||||
| 2 | Rotation of the shoulder rolling joint | -2 | 2 | r_upper_arm_roll_joint | hinge | torque (N m) |
|
| 2 | Rotation of the shoulder rolling joint | -2 | 2 | r_upper_arm_roll_joint | hinge | torque (N m) |
|
||||||
|
@@ -31,7 +31,7 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
The observation is a `ndarray` with shape `(11,)` where the elements correspond to the following:
|
The observation is a `ndarray` with shape `(11,)` where the elements correspond to the following:
|
||||||
|
|
||||||
| Num | Observation | Min | Max | Name (in corresponding XML file) | Joint | Unit |
|
| Num | Observation | Min | Max | Name (in corresponding XML file) | Joint | Unit |
|
||||||
|-----|-------------------------------------------------------------------------------------------------|------|-----|----------------------------------|-------|--------------------------|
|
|-----|------------------------------------------------------------------------------------------------|------|-----|----------------------------------|-------|--------------------------|
|
||||||
| 0 | cosine of the angle of the first arm | -Inf | Inf | cos(joint0) | hinge | unitless |
|
| 0 | cosine of the angle of the first arm | -Inf | Inf | cos(joint0) | hinge | unitless |
|
||||||
| 1 | cosine of the angle of the second arm | -Inf | Inf | cos(joint1) | hinge | unitless |
|
| 1 | cosine of the angle of the second arm | -Inf | Inf | cos(joint1) | hinge | unitless |
|
||||||
| 2 | sine of the angle of the first arm | -Inf | Inf | cos(joint0) | hinge | unitless |
|
| 2 | sine of the angle of the first arm | -Inf | Inf | cos(joint0) | hinge | unitless |
|
||||||
|
@@ -16,8 +16,6 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
SupportsFloat,
|
SupportsFloat,
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
Union,
|
Union,
|
||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
@@ -49,14 +47,14 @@ ENV_ID_RE: re.Pattern = re.compile(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load(name: str) -> Type:
|
def load(name: str) -> type:
|
||||||
mod_name, attr_name = name.split(":")
|
mod_name, attr_name = name.split(":")
|
||||||
mod = importlib.import_module(mod_name)
|
mod = importlib.import_module(mod_name)
|
||||||
fn = getattr(mod, attr_name)
|
fn = getattr(mod, attr_name)
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
|
def parse_env_id(id: str) -> tuple[Optional[str], str, Optional[int]]:
|
||||||
"""Parse environment ID string format.
|
"""Parse environment ID string format.
|
||||||
|
|
||||||
This format is true today, but it's *not* an official spec.
|
This format is true today, but it's *not* an official spec.
|
||||||
@@ -64,6 +62,15 @@ def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
|
|||||||
|
|
||||||
2016-10-31: We're experimentally expanding the environment ID format
|
2016-10-31: We're experimentally expanding the environment ID format
|
||||||
to include an optional namespace.
|
to include an optional namespace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id: The environment id to parse
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of environment namespace, environment name and version number
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Error: If the environment id does not a valid environment regex
|
||||||
"""
|
"""
|
||||||
match = ENV_ID_RE.fullmatch(id)
|
match = ENV_ID_RE.fullmatch(id)
|
||||||
if not match:
|
if not match:
|
||||||
@@ -78,9 +85,17 @@ def parse_env_id(id: str) -> Tuple[Optional[str], str, Optional[int]]:
|
|||||||
return namespace, name, version
|
return namespace, name, version
|
||||||
|
|
||||||
|
|
||||||
def get_env_id(ns: Optional[str], name: str, version: Optional[int]):
|
def get_env_id(ns: Optional[str], name: str, version: Optional[int]) -> str:
|
||||||
"""Get the full env ID given a name and (optional) version and namespace.
|
"""Get the full env ID given a name and (optional) version and namespace. Inverse of :meth:`parse_env_id`.
|
||||||
Inverse of parse_env_id."""
|
|
||||||
|
Args:
|
||||||
|
ns: The environment namespace
|
||||||
|
name: The environment name
|
||||||
|
version: The environment version
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The environment id
|
||||||
|
"""
|
||||||
|
|
||||||
full_name = name
|
full_name = name
|
||||||
if version is not None:
|
if version is not None:
|
||||||
@@ -172,7 +187,18 @@ def _check_name_exists(ns: Optional[str], name: str):
|
|||||||
|
|
||||||
def _check_version_exists(ns: Optional[str], name: str, version: Optional[int]):
|
def _check_version_exists(ns: Optional[str], name: str, version: Optional[int]):
|
||||||
"""Check if an env version exists in a namespace. If it doesn't, print a helpful error message.
|
"""Check if an env version exists in a namespace. If it doesn't, print a helpful error message.
|
||||||
This is a complete test whether an environment identifier is valid, and will provide the best available hints."""
|
This is a complete test whether an environment identifier is valid, and will provide the best available hints.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ns: The environment namespace
|
||||||
|
name: The environment space
|
||||||
|
version: The environment version
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DeprecatedEnv: The environment doesn't exist but a default version does
|
||||||
|
VersionNotFound: The ``version`` used doesn't exist
|
||||||
|
DeprecatedEnv: Environment version is deprecated
|
||||||
|
"""
|
||||||
if get_env_id(ns, name, version) in registry:
|
if get_env_id(ns, name, version) in registry:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -344,6 +370,7 @@ class EnvRegistry(dict):
|
|||||||
Turns out that some existing code directly used the old `EnvRegistry` code,
|
Turns out that some existing code directly used the old `EnvRegistry` code,
|
||||||
even though the intended API was just `register` and `make`.
|
even though the intended API was just `register` and `make`.
|
||||||
This reimplements some old methods, so that e.g. pybullet environments will still work.
|
This reimplements some old methods, so that e.g. pybullet environments will still work.
|
||||||
|
|
||||||
Ideally, nobody should ever use these methods, and they will be removed soon.
|
Ideally, nobody should ever use these methods, and they will be removed soon.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -458,13 +485,16 @@ def namespace(ns: str):
|
|||||||
|
|
||||||
|
|
||||||
def register(id: str, **kwargs):
|
def register(id: str, **kwargs):
|
||||||
"""
|
"""Register an environment with gym.
|
||||||
Register an environment with gym. The `id` parameter corresponds to the name of the environment,
|
|
||||||
with the syntax as follows:
|
The `id` parameter corresponds to the name of the environment, with the syntax as follows:
|
||||||
`(namespace)/(env_name)-v(version)`
|
`(namespace)/(env_name)-v(version)` where `namespace` is optional.
|
||||||
where `namespace` is optional.
|
|
||||||
|
|
||||||
It takes arbitrary keyword arguments, which are passed to the `EnvSpec` constructor.
|
It takes arbitrary keyword arguments, which are passed to the `EnvSpec` constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id: The environment id
|
||||||
|
**kwargs: arbitrary keyword arguments which are passed to the environment constructor
|
||||||
"""
|
"""
|
||||||
global registry, current_namespace
|
global registry, current_namespace
|
||||||
ns, name, version = parse_env_id(id)
|
ns, name, version = parse_env_id(id)
|
||||||
@@ -498,8 +528,7 @@ def make(
|
|||||||
disable_env_checker: bool = False,
|
disable_env_checker: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Env:
|
) -> Env:
|
||||||
"""
|
"""Create an environment according to the given ID.
|
||||||
Create an environment according to the given ID.
|
|
||||||
|
|
||||||
Warnings:
|
Warnings:
|
||||||
In v0.24, `gym.utils.env_checker.env_checker` is run for every initialised environment.
|
In v0.24, `gym.utils.env_checker.env_checker` is run for every initialised environment.
|
||||||
@@ -512,8 +541,12 @@ def make(
|
|||||||
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
|
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
|
||||||
disable_env_checker: If to disable the environment checker
|
disable_env_checker: If to disable the environment checker
|
||||||
kwargs: Additional arguments to pass to the environment constructor.
|
kwargs: Additional arguments to pass to the environment constructor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An instance of the environment.
|
An instance of the environment.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Error: If the ``id`` doesn't exist then an error is raised
|
||||||
"""
|
"""
|
||||||
if isinstance(id, EnvSpec):
|
if isinstance(id, EnvSpec):
|
||||||
spec_ = id
|
spec_ = id
|
||||||
@@ -588,7 +621,8 @@ def make(
|
|||||||
check_env(env)
|
check_env(env)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
f"Env check failed with the following message: {e}\nYou can call `gym.make(..., disable_env_checker=True)` to disable this check."
|
f"Env check failed with the following message: {e}\n"
|
||||||
|
f"You can set `disable_env_checker=True` to disable this check."
|
||||||
)
|
)
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
@@ -42,8 +42,10 @@ class CliffWalkingEnv(Env):
|
|||||||
- 3: move left
|
- 3: move left
|
||||||
|
|
||||||
### Observations
|
### Observations
|
||||||
There are 3x12 + 1 possible states. In fact, the agent cannot be at the cliff, nor at the goal (as this results the end of episode). They remain all the positions of the first 3 rows plus the bottom-left cell.
|
There are 3x12 + 1 possible states. In fact, the agent cannot be at the cliff, nor at the goal
|
||||||
The observation is simply the current position encoded as [flattened index](https://numpy.org/doc/stable/reference/generated/numpy.unravel_index.html).
|
(as this results the end of episode). They remain all the positions of the first 3 rows plus the bottom-left cell.
|
||||||
|
The observation is simply the current position encoded as
|
||||||
|
[flattened index](https://numpy.org/doc/stable/reference/generated/numpy.unravel_index.html).
|
||||||
|
|
||||||
### Reward
|
### Reward
|
||||||
Each time step incurs -1 reward, and stepping into the cliff incurs -100 reward.
|
Each time step incurs -1 reward, and stepping into the cliff incurs -100 reward.
|
||||||
@@ -89,12 +91,8 @@ class CliffWalkingEnv(Env):
|
|||||||
self.observation_space = spaces.Discrete(self.nS)
|
self.observation_space = spaces.Discrete(self.nS)
|
||||||
self.action_space = spaces.Discrete(self.nA)
|
self.action_space = spaces.Discrete(self.nA)
|
||||||
|
|
||||||
def _limit_coordinates(self, coord):
|
def _limit_coordinates(self, coord: np.ndarray) -> np.ndarray:
|
||||||
"""
|
"""Prevent the agent from falling out of the grid world."""
|
||||||
Prevent the agent from falling out of the grid world
|
|
||||||
:param coord:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
coord[0] = min(coord[0], self.shape[0] - 1)
|
coord[0] = min(coord[0], self.shape[0] - 1)
|
||||||
coord[0] = max(coord[0], 0)
|
coord[0] = max(coord[0], 0)
|
||||||
coord[1] = min(coord[1], self.shape[1] - 1)
|
coord[1] = min(coord[1], self.shape[1] - 1)
|
||||||
@@ -102,11 +100,14 @@ class CliffWalkingEnv(Env):
|
|||||||
return coord
|
return coord
|
||||||
|
|
||||||
def _calculate_transition_prob(self, current, delta):
|
def _calculate_transition_prob(self, current, delta):
|
||||||
"""
|
"""Determine the outcome for an action. Transition Prob is always 1.0.
|
||||||
Determine the outcome for an action. Transition Prob is always 1.0.
|
|
||||||
:param current: Current position on the grid as (row, col)
|
Args:
|
||||||
:param delta: Change in position for transition
|
current: Current position on the grid as (row, col)
|
||||||
:return: (1.0, new_state, reward, done)
|
delta: Change in position for transition
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of ``(1.0, new_state, reward, done)``
|
||||||
"""
|
"""
|
||||||
new_position = np.array(current) + np.array(delta)
|
new_position = np.array(current) + np.array(delta)
|
||||||
new_position = self._limit_coordinates(new_position).astype(int)
|
new_position = self._limit_coordinates(new_position).astype(int)
|
||||||
|
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from os import path
|
from os import path
|
||||||
@@ -29,10 +31,15 @@ MAPS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def generate_random_map(size=8, p=0.8):
|
def generate_random_map(size: int = 8, p: float = 0.8) -> list[str]:
|
||||||
"""Generates a random valid map (one that has a path from start to goal)
|
"""Generates a random valid map (one that has a path from start to goal)
|
||||||
:param size: size of each side of the grid
|
|
||||||
:param p: probability that a tile is frozen
|
Args:
|
||||||
|
size: size of each side of the grid
|
||||||
|
p: probability that a tile is frozen
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A random valid map
|
||||||
"""
|
"""
|
||||||
valid = False
|
valid = False
|
||||||
|
|
||||||
@@ -67,8 +74,9 @@ def generate_random_map(size=8, p=0.8):
|
|||||||
|
|
||||||
class FrozenLakeEnv(Env):
|
class FrozenLakeEnv(Env):
|
||||||
"""
|
"""
|
||||||
Frozen lake involves crossing a frozen lake from Start(S) to Goal(G) without falling into any Holes(H) by walking over
|
Frozen lake involves crossing a frozen lake from Start(S) to Goal(G) without falling into any Holes(H)
|
||||||
the Frozen(F) lake. The agent may not always move in the intended direction due to the slippery nature of the frozen lake.
|
by walking over the Frozen(F) lake.
|
||||||
|
The agent may not always move in the intended direction due to the slippery nature of the frozen lake.
|
||||||
|
|
||||||
|
|
||||||
### Action Space
|
### Action Space
|
||||||
|
@@ -1,11 +1,10 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from gym.utils import seeding
|
||||||
|
|
||||||
def categorical_sample(prob_n, np_random):
|
|
||||||
"""
|
def categorical_sample(prob_n, np_random: seeding.RandomNumberGenerator):
|
||||||
Sample from categorical distribution
|
"""Sample from categorical distribution where each row specifies class probabilities."""
|
||||||
Each row specifies class probabilities
|
|
||||||
"""
|
|
||||||
prob_n = np.asarray(prob_n)
|
prob_n = np.asarray(prob_n)
|
||||||
csprob_n = np.cumsum(prob_n)
|
csprob_n = np.cumsum(prob_n)
|
||||||
return (csprob_n > np_random.random()).argmax()
|
return np.argmax(csprob_n > np_random.random())
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
"""Implementation of a space that represents closed boxes in euclidean space."""
|
"""Implementation of a space that represents closed boxes in euclidean space."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Optional, Sequence, SupportsFloat, Tuple, Type, Union
|
from typing import Optional, Sequence, SupportsFloat, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -15,6 +15,12 @@ def _short_repr(arr: np.ndarray) -> str:
|
|||||||
|
|
||||||
If arr is a multiple of the all-ones vector, return a string representation of the multiplier.
|
If arr is a multiple of the all-ones vector, return a string representation of the multiplier.
|
||||||
Otherwise, return a string representation of the entire array.
|
Otherwise, return a string representation of the entire array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arr: The array to represent
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A short representation of the array
|
||||||
"""
|
"""
|
||||||
if arr.size != 0 and np.min(arr) == np.max(arr):
|
if arr.size != 0 and np.min(arr) == np.max(arr):
|
||||||
return str(np.min(arr))
|
return str(np.min(arr))
|
||||||
@@ -46,7 +52,7 @@ class Box(Space[np.ndarray]):
|
|||||||
low: Union[SupportsFloat, np.ndarray],
|
low: Union[SupportsFloat, np.ndarray],
|
||||||
high: Union[SupportsFloat, np.ndarray],
|
high: Union[SupportsFloat, np.ndarray],
|
||||||
shape: Optional[Sequence[int]] = None,
|
shape: Optional[Sequence[int]] = None,
|
||||||
dtype: Type = np.float32,
|
dtype: type = np.float32,
|
||||||
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
||||||
):
|
):
|
||||||
r"""Constructor of :class:`Box`.
|
r"""Constructor of :class:`Box`.
|
||||||
@@ -57,7 +63,6 @@ class Box(Space[np.ndarray]):
|
|||||||
If ``low`` (or ``high``) is a scalar, the lower bound (or upper bound, respectively) will be assumed to be
|
If ``low`` (or ``high``) is a scalar, the lower bound (or upper bound, respectively) will be assumed to be
|
||||||
this value across all dimensions.
|
this value across all dimensions.
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
low (Union[SupportsFloat, np.ndarray]): Lower bounds of the intervals.
|
low (Union[SupportsFloat, np.ndarray]): Lower bounds of the intervals.
|
||||||
high (Union[SupportsFloat, np.ndarray]): Upper bounds of the intervals.
|
high (Union[SupportsFloat, np.ndarray]): Upper bounds of the intervals.
|
||||||
@@ -65,6 +70,10 @@ class Box(Space[np.ndarray]):
|
|||||||
Otherwise, the shape is inferred from the shape of ``low`` or ``high``.
|
Otherwise, the shape is inferred from the shape of ``low`` or ``high``.
|
||||||
dtype: The dtype of the elements of the space. If this is an integer type, the :class:`Box` is essentially a discrete space.
|
dtype: The dtype of the elements of the space. If this is an integer type, the :class:`Box` is essentially a discrete space.
|
||||||
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no shape information is provided (shape is None, low is None and high is None) then a
|
||||||
|
value error is raised.
|
||||||
"""
|
"""
|
||||||
assert dtype is not None, "dtype must be explicitly provided. "
|
assert dtype is not None, "dtype must be explicitly provided. "
|
||||||
self.dtype = np.dtype(dtype)
|
self.dtype = np.dtype(dtype)
|
||||||
@@ -96,7 +105,7 @@ class Box(Space[np.ndarray]):
|
|||||||
assert isinstance(high, np.ndarray)
|
assert isinstance(high, np.ndarray)
|
||||||
assert high.shape == shape, "high.shape doesn't match provided shape"
|
assert high.shape == shape, "high.shape doesn't match provided shape"
|
||||||
|
|
||||||
self._shape: Tuple[int, ...] = shape
|
self._shape: tuple[int, ...] = shape
|
||||||
|
|
||||||
low_precision = get_precision(low.dtype)
|
low_precision = get_precision(low.dtype)
|
||||||
high_precision = get_precision(high.dtype)
|
high_precision = get_precision(high.dtype)
|
||||||
@@ -112,7 +121,7 @@ class Box(Space[np.ndarray]):
|
|||||||
super().__init__(self.shape, self.dtype, seed)
|
super().__init__(self.shape, self.dtype, seed)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> Tuple[int, ...]:
|
def shape(self) -> tuple[int, ...]:
|
||||||
"""Has stricter type than gym.Space - never None."""
|
"""Has stricter type than gym.Space - never None."""
|
||||||
return self._shape
|
return self._shape
|
||||||
|
|
||||||
@@ -122,6 +131,9 @@ class Box(Space[np.ndarray]):
|
|||||||
Args:
|
Args:
|
||||||
manner (str): One of ``"both"``, ``"below"``, ``"above"``.
|
manner (str): One of ``"both"``, ``"below"``, ``"above"``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If the space is bounded
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If `manner` is neither ``"both"`` nor ``"below"`` or ``"above"``
|
ValueError: If `manner` is neither ``"both"`` nor ``"below"`` or ``"above"``
|
||||||
"""
|
"""
|
||||||
@@ -146,6 +158,9 @@ class Box(Space[np.ndarray]):
|
|||||||
* :math:`[a, \infty)` : shifted exponential distribution
|
* :math:`[a, \infty)` : shifted exponential distribution
|
||||||
* :math:`(-\infty, b]` : shifted negative exponential distribution
|
* :math:`(-\infty, b]` : shifted negative exponential distribution
|
||||||
* :math:`(-\infty, \infty)` : normal distribution
|
* :math:`(-\infty, \infty)` : normal distribution
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sampled value from the Box
|
||||||
"""
|
"""
|
||||||
high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
|
high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
|
||||||
sample = np.empty(self.shape)
|
sample = np.empty(self.shape)
|
||||||
@@ -204,6 +219,9 @@ class Box(Space[np.ndarray]):
|
|||||||
|
|
||||||
The representation will include bounds, shape and dtype.
|
The representation will include bounds, shape and dtype.
|
||||||
If a bound is uniform, only the corresponding scalar will be given to avoid redundant and ugly strings.
|
If a bound is uniform, only the corresponding scalar will be given to avoid redundant and ugly strings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A representation of the space
|
||||||
"""
|
"""
|
||||||
return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
|
return f"Box({self.low_repr}, {self.high_repr}, {self.shape}, {self.dtype})"
|
||||||
|
|
||||||
@@ -223,6 +241,13 @@ def get_inf(dtype, sign: str) -> SupportsFloat:
|
|||||||
Args:
|
Args:
|
||||||
dtype: An `np.dtype`
|
dtype: An `np.dtype`
|
||||||
sign (str): must be either `"+"` or `"-"`
|
sign (str): must be either `"+"` or `"-"`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Gets an infinite value with the sign and dtype
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: Unknown sign, use either '+' or '-'
|
||||||
|
ValueError: Unknown dtype for infinite bounds
|
||||||
"""
|
"""
|
||||||
if np.dtype(dtype).kind == "f":
|
if np.dtype(dtype).kind == "f":
|
||||||
if sign == "+":
|
if sign == "+":
|
||||||
|
@@ -143,6 +143,9 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
"""Generates a single random sample from this space.
|
"""Generates a single random sample from this space.
|
||||||
|
|
||||||
The sample is an ordered dictionary of independent samples from the constituent spaces.
|
The sample is an ordered dictionary of independent samples from the constituent spaces.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with the same key and sampled values from :attr:`self.spaces`
|
||||||
"""
|
"""
|
||||||
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
|
return OrderedDict([(k, space.sample()) for k, space in self.spaces.items()])
|
||||||
|
|
||||||
@@ -157,11 +160,11 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key: str) -> Space:
|
||||||
"""Get the space that is associated to `key`."""
|
"""Get the space that is associated to `key`."""
|
||||||
return self.spaces[key]
|
return self.spaces[key]
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key: str, value: Space):
|
||||||
"""Set the space that is associated to `key`."""
|
"""Set the space that is associated to `key`."""
|
||||||
self.spaces[key] = value
|
self.spaces[key] = value
|
||||||
|
|
||||||
@@ -175,11 +178,7 @@ class Dict(Space[TypingDict[str, Space]], Mapping):
|
|||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""Gives a string representation of this space."""
|
"""Gives a string representation of this space."""
|
||||||
return (
|
return "Dict(" + ", ".join([f"{k}: {s}" for k, s in self.spaces.items()]) + ")"
|
||||||
"Dict("
|
|
||||||
+ ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()])
|
|
||||||
+ ")"
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_jsonable(self, sample_n: list) -> dict:
|
def to_jsonable(self, sample_n: list) -> dict:
|
||||||
"""Convert a batch of samples from this space to a JSONable data type."""
|
"""Convert a batch of samples from this space to a JSONable data type."""
|
||||||
|
@@ -45,6 +45,9 @@ class Discrete(Space[int]):
|
|||||||
"""Generates a single random sample from this space.
|
"""Generates a single random sample from this space.
|
||||||
|
|
||||||
A sample will be chosen uniformly at random.
|
A sample will be chosen uniformly at random.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sampled integer from the space
|
||||||
"""
|
"""
|
||||||
return int(self.start + self.np_random.integers(self.n))
|
return int(self.start + self.np_random.integers(self.n))
|
||||||
|
|
||||||
@@ -78,6 +81,9 @@ class Discrete(Space[int]):
|
|||||||
"""Used when loading a pickled space.
|
"""Used when loading a pickled space.
|
||||||
|
|
||||||
This method has to be implemented explicitly to allow for loading of legacy states.
|
This method has to be implemented explicitly to allow for loading of legacy states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The new state
|
||||||
"""
|
"""
|
||||||
super().__setstate__(state)
|
super().__setstate__(state)
|
||||||
|
|
||||||
|
@@ -57,6 +57,9 @@ class MultiBinary(Space[np.ndarray]):
|
|||||||
"""Generates a single random sample from this space.
|
"""Generates a single random sample from this space.
|
||||||
|
|
||||||
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
|
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sampled values from space
|
||||||
"""
|
"""
|
||||||
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
|
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
|
||||||
|
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
"""Implementation of the `Space` metaclass."""
|
"""Implementation of the `Space` metaclass."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Generic, Iterable, Mapping, Optional, Sequence, Type, TypeVar
|
from typing import Generic, Iterable, Mapping, Optional, Sequence, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -16,8 +16,10 @@ class Space(Generic[T_cov]):
|
|||||||
Spaces are crucially used in Gym to define the format of valid actions and observations.
|
Spaces are crucially used in Gym to define the format of valid actions and observations.
|
||||||
They serve various purposes:
|
They serve various purposes:
|
||||||
|
|
||||||
* They clearly define how to interact with environments, i.e. they specify what actions need to look like and what observations will look like
|
* They clearly define how to interact with environments, i.e. they specify what actions need to look like
|
||||||
* They allow us to work with highly structured data (e.g. in the form of elements of :class:`Dict` spaces) and painlessly transform them into flat arrays that can be used in learning code
|
and what observations will look like
|
||||||
|
* They allow us to work with highly structured data (e.g. in the form of elements of :class:`Dict` spaces)
|
||||||
|
and painlessly transform them into flat arrays that can be used in learning code
|
||||||
* They provide a method to sample random elements. This is especially useful for exploration and debugging.
|
* They provide a method to sample random elements. This is especially useful for exploration and debugging.
|
||||||
|
|
||||||
Different spaces can be combined hierarchically via container spaces (:class:`Tuple` and :class:`Dict`) to build a
|
Different spaces can be combined hierarchically via container spaces (:class:`Tuple` and :class:`Dict`) to build a
|
||||||
@@ -37,7 +39,7 @@ class Space(Generic[T_cov]):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
shape: Optional[Sequence[int]] = None,
|
shape: Optional[Sequence[int]] = None,
|
||||||
dtype: Optional[Type | str] = None,
|
dtype: Optional[type | str | np.dtype] = None,
|
||||||
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
seed: Optional[int | seeding.RandomNumberGenerator] = None,
|
||||||
):
|
):
|
||||||
"""Constructor of :class:`Space`.
|
"""Constructor of :class:`Space`.
|
||||||
@@ -90,6 +92,9 @@ class Space(Generic[T_cov]):
|
|||||||
"""Used when loading a pickled space.
|
"""Used when loading a pickled space.
|
||||||
|
|
||||||
This method was implemented explicitly to allow for loading of legacy states.
|
This method was implemented explicitly to allow for loading of legacy states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The updated state value
|
||||||
"""
|
"""
|
||||||
# Don't mutate the original state
|
# Don't mutate the original state
|
||||||
state = dict(state)
|
state = dict(state)
|
||||||
|
@@ -79,6 +79,9 @@ class Tuple(Space[tuple], Sequence):
|
|||||||
"""Generates a single random sample inside this space.
|
"""Generates a single random sample inside this space.
|
||||||
|
|
||||||
This method draws independent samples from the subspaces.
|
This method draws independent samples from the subspaces.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of the subspace's samples
|
||||||
"""
|
"""
|
||||||
return tuple(space.sample() for space in self.spaces)
|
return tuple(space.sample() for space in self.spaces)
|
||||||
|
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
"""Implementation of utility functions that can be applied to spaces.
|
"""Implementation of utility functions that can be applied to spaces.
|
||||||
|
|
||||||
These functions mostly take care of flattening and unflattening elements of spaces to facilitate their usage in learning code.
|
These functions mostly take care of flattening and unflattening elements of spaces
|
||||||
|
to facilitate their usage in learning code.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -18,17 +19,21 @@ from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, T
|
|||||||
def flatdim(space: Space) -> int:
|
def flatdim(space: Space) -> int:
|
||||||
"""Return the number of dimensions a flattened equivalent of this space would have.
|
"""Return the number of dimensions a flattened equivalent of this space would have.
|
||||||
|
|
||||||
Accepts a space and returns an integer.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotImplementedError: if the space is not defined in ``gym.spaces``.
|
|
||||||
|
|
||||||
Example usage::
|
Example usage::
|
||||||
|
|
||||||
>>> from gym.spaces import Discrete
|
>>> from gym.spaces import Discrete
|
||||||
>>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
|
>>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
|
||||||
>>> flatdim(space)
|
>>> flatdim(space)
|
||||||
5
|
5
|
||||||
|
|
||||||
|
Args:
|
||||||
|
space: The space to return the number of dimensions of the flattened spaces
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The number of dimensions for the flattened spaces
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: if the space is not defined in ``gym.spaces``.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"Unknown space: `{space}`")
|
raise NotImplementedError(f"Unknown space: `{space}`")
|
||||||
|
|
||||||
@@ -69,9 +74,15 @@ def flatten(space: Space[T], x: T) -> np.ndarray:
|
|||||||
This is useful when e.g. points from spaces must be passed to a neural
|
This is useful when e.g. points from spaces must be passed to a neural
|
||||||
network, which only understands flat arrays of floats.
|
network, which only understands flat arrays of floats.
|
||||||
|
|
||||||
Accepts a space and a point from that space. Always returns a 1D array.
|
Args:
|
||||||
Raises ``NotImplementedError`` if the space is not defined in
|
space: The space that ``x`` is flattened by
|
||||||
``gym.spaces``.
|
x: The value to flatten
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The flattened ``x``, always returns a 1D array.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: If the space is not defined in ``gym.spaces``.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"Unknown space: `{space}`")
|
raise NotImplementedError(f"Unknown space: `{space}`")
|
||||||
|
|
||||||
@@ -116,9 +127,15 @@ def unflatten(space: Space[T], x: np.ndarray) -> T:
|
|||||||
This reverses the transformation applied by :func:`flatten`. You must ensure
|
This reverses the transformation applied by :func:`flatten`. You must ensure
|
||||||
that the ``space`` argument is the same as for the :func:`flatten` call.
|
that the ``space`` argument is the same as for the :func:`flatten` call.
|
||||||
|
|
||||||
Accepts a space and a flattened point. Returns a point with a structure
|
Args:
|
||||||
that matches the space. Raises ``NotImplementedError`` if the space is not
|
space: The space used to unflatten ``x``
|
||||||
defined in ``gym.spaces``.
|
x: The array to unflatten
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A point with a structure that matches the space.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: if the space is not defined in ``gym.spaces``.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"Unknown space: `{space}`")
|
raise NotImplementedError(f"Unknown space: `{space}`")
|
||||||
|
|
||||||
@@ -173,9 +190,6 @@ def flatten_space(space: Space) -> Box:
|
|||||||
:func:`flatdim` dimensions. Flattening a sample of the original space
|
:func:`flatdim` dimensions. Flattening a sample of the original space
|
||||||
has the same effect as taking a sample of the flattenend space.
|
has the same effect as taking a sample of the flattenend space.
|
||||||
|
|
||||||
Raises ``NotImplementedError`` if the space is not defined in
|
|
||||||
``gym.spaces``.
|
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
>>> box = Box(0.0, 1.0, shape=(3, 4, 5))
|
>>> box = Box(0.0, 1.0, shape=(3, 4, 5))
|
||||||
@@ -201,6 +215,15 @@ def flatten_space(space: Space) -> Box:
|
|||||||
Box(6,)
|
Box(6,)
|
||||||
>>> flatten(space, space.sample()) in flatten_space(space)
|
>>> flatten(space, space.sample()) in flatten_space(space)
|
||||||
True
|
True
|
||||||
|
|
||||||
|
Args:
|
||||||
|
space: The space to flatten
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A flattened Box
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: if the space is not defined in ``gym.spaces``.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(f"Unknown space: `{space}`")
|
raise NotImplementedError(f"Unknown space: `{space}`")
|
||||||
|
|
||||||
|
@@ -210,6 +210,10 @@ def _check_returned_values(env: gym.Env, observation_space: Space, action_space:
|
|||||||
env: The environment
|
env: The environment
|
||||||
observation_space: The environment's observation space
|
observation_space: The environment's observation space
|
||||||
action_space: The environment's action space
|
action_space: The environment's action space
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the ``observation_space`` is :class:`Dict` and
|
||||||
|
keys from :meth:`Env.reset` are not in the observation space
|
||||||
"""
|
"""
|
||||||
# because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists
|
# because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
@@ -329,6 +333,10 @@ def _check_reset_seed(env: gym.Env, seed: Optional[int] = None):
|
|||||||
Args:
|
Args:
|
||||||
env: The environment to check
|
env: The environment to check
|
||||||
seed: The optional seed to use
|
seed: The optional seed to use
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: The environment cannot be reset with a random seed,
|
||||||
|
even though `seed` or `kwargs` appear in the signature.
|
||||||
"""
|
"""
|
||||||
signature = inspect.signature(env.reset)
|
signature = inspect.signature(env.reset)
|
||||||
assert (
|
assert (
|
||||||
@@ -365,6 +373,10 @@ def _check_reset_info(env: gym.Env):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
env: The environment to check
|
env: The environment to check
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: The environment cannot be reset with `return_info=True`,
|
||||||
|
even though `return_info` or `kwargs` appear in the signature.
|
||||||
"""
|
"""
|
||||||
signature = inspect.signature(env.reset)
|
signature = inspect.signature(env.reset)
|
||||||
assert (
|
assert (
|
||||||
@@ -394,6 +406,10 @@ def _check_reset_options(env: gym.Env):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
env: The environment to check
|
env: The environment to check
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: The environment cannot be reset with options,
|
||||||
|
even though `options` or `kwargs` appear in the signature.
|
||||||
"""
|
"""
|
||||||
signature = inspect.signature(env.reset)
|
signature = inspect.signature(env.reset)
|
||||||
assert (
|
assert (
|
||||||
|
@@ -62,8 +62,8 @@ class PlayableGame:
|
|||||||
keys_to_action = self.env.unwrapped.get_keys_to_action()
|
keys_to_action = self.env.unwrapped.get_keys_to_action()
|
||||||
else:
|
else:
|
||||||
raise MissingKeysToAction(
|
raise MissingKeysToAction(
|
||||||
"%s does not have explicit key to action mapping, "
|
f"{self.env.spec.id} does not have explicit key to action mapping, "
|
||||||
"please specify one manually" % self.env.spec.id
|
"please specify one manually"
|
||||||
)
|
)
|
||||||
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
|
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
|
||||||
return relevant_keys
|
return relevant_keys
|
||||||
@@ -81,7 +81,8 @@ class PlayableGame:
|
|||||||
def process_event(self, event: Event):
|
def process_event(self, event: Event):
|
||||||
"""Processes a PyGame event.
|
"""Processes a PyGame event.
|
||||||
|
|
||||||
In particular, this function is used to keep track of which buttons are currently pressed and to exit the :func:`play` function when the PyGame window is closed.
|
In particular, this function is used to keep track of which buttons are currently pressed
|
||||||
|
and to exit the :func:`play` function when the PyGame window is closed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event: The event to process
|
event: The event to process
|
||||||
@@ -258,15 +259,16 @@ class PlayPlot:
|
|||||||
It should return a list of metrics that are computed from this data.
|
It should return a list of metrics that are computed from this data.
|
||||||
For instance, the function may look like this::
|
For instance, the function may look like this::
|
||||||
|
|
||||||
def compute_metrics(obs_t, obs_tp, action, reward, done, info):
|
>>> def compute_metrics(obs_t, obs_tp, action, reward, done, info):
|
||||||
return [reward, info["cumulative_reward"], np.linalg.norm(action)]
|
... return [reward, info["cumulative_reward"], np.linalg.norm(action)]
|
||||||
|
|
||||||
:class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function
|
:class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function
|
||||||
and uses the returned values to update live plots of the metrics.
|
and uses the returned values to update live plots of the metrics.
|
||||||
|
|
||||||
Typically, this :meth:`callback` will be used in conjunction with :func:`play` to see how the metrics evolve as you play::
|
Typically, this :meth:`callback` will be used in conjunction with :func:`play` to see how the metrics evolve as you play::
|
||||||
|
|
||||||
>>> plotter = PlayPlot(compute_metrics, horizon_timesteps=200, plot_names=["Immediate Rew.", "Cumulative Rew.", "Action Magnitude"])
|
>>> plotter = PlayPlot(compute_metrics, horizon_timesteps=200,
|
||||||
|
... plot_names=["Immediate Rew.", "Cumulative Rew.", "Action Magnitude"])
|
||||||
>>> play(your_env, callback=plotter.callback)
|
>>> play(your_env, callback=plotter.callback)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -282,6 +284,9 @@ class PlayPlot:
|
|||||||
callback: Function that computes metrics from environment transitions
|
callback: Function that computes metrics from environment transitions
|
||||||
horizon_timesteps: The time horizon used for the live plots
|
horizon_timesteps: The time horizon used for the live plots
|
||||||
plot_names: List of plot titles
|
plot_names: List of plot titles
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DependencyNotInstalled: If matplotlib is not installed
|
||||||
"""
|
"""
|
||||||
deprecation(
|
deprecation(
|
||||||
"`PlayPlot` is marked as deprecated and will be removed in the near future."
|
"`PlayPlot` is marked as deprecated and will be removed in the near future."
|
||||||
|
@@ -20,6 +20,9 @@ def np_random(seed: Optional[int] = None) -> tuple[RandomNumberGenerator, Any]:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The generator and resulting seed
|
The generator and resulting seed
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Error: Seed must be a non-negative integer or omitted
|
||||||
"""
|
"""
|
||||||
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
|
if seed is not None and not (isinstance(seed, int) and 0 <= seed):
|
||||||
raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")
|
raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")
|
||||||
@@ -175,6 +178,9 @@ def create_seed(a: Optional[Union[int, str]] = None, max_bytes: int = 8) -> int:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A seed
|
A seed
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Error: Invalid type for seed, expects None or str or int
|
||||||
"""
|
"""
|
||||||
deprecation(
|
deprecation(
|
||||||
"Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "
|
"Function `create_seed(a, max_bytes)` is marked as deprecated and will be removed in the future. "
|
||||||
|
@@ -73,19 +73,30 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
env_fns: Functions that create the environments.
|
env_fns: Functions that create the environments.
|
||||||
observation_space: Observation space of a single environment. If ``None``, then the observation space of the first environment is taken.
|
observation_space: Observation space of a single environment. If ``None``,
|
||||||
action_space: Action space of a single environment. If ``None``, then the action space of the first environment is taken.
|
then the observation space of the first environment is taken.
|
||||||
shared_memory: If ``True``, then the observations from the worker processes are communicated back through shared variables. This can improve the efficiency if the observations are large (e.g. images).
|
action_space: Action space of a single environment. If ``None``,
|
||||||
copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods return a copy of the observations.
|
then the action space of the first environment is taken.
|
||||||
|
shared_memory: If ``True``, then the observations from the worker processes are communicated back through
|
||||||
|
shared variables. This can improve the efficiency if the observations are large (e.g. images).
|
||||||
|
copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods
|
||||||
|
return a copy of the observations.
|
||||||
context: Context for `multiprocessing`_. If ``None``, then the default context is used.
|
context: Context for `multiprocessing`_. If ``None``, then the default context is used.
|
||||||
daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if the head process quits. However, ``daemon=True`` prevents subprocesses to spawn children, so for some environments you may want to have it set to ``False``.
|
daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if
|
||||||
worker: If set, then use that worker in a subprocess instead of a default one. Can be useful to override some inner vector env logic, for instance, how resets on done are handled.
|
the head process quits. However, ``daemon=True`` prevents subprocesses to spawn children,
|
||||||
|
so for some environments you may want to have it set to ``False``.
|
||||||
|
worker: If set, then use that worker in a subprocess instead of a default one.
|
||||||
|
Can be useful to override some inner vector env logic, for instance, how resets on done are handled.
|
||||||
|
|
||||||
Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start from the code for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.
|
Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance
|
||||||
|
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
|
||||||
|
from the code for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If the observation space of some sub-environment does not match observation_space (or, by default, the observation space of the first sub-environment).
|
RuntimeError: If the observation space of some sub-environment does not match observation_space
|
||||||
ValueError: If observation_space is a custom space (i.e. not a default space in Gym, such as gym.spaces.Box, gym.spaces.Discrete, or gym.spaces.Dict) and shared_memory is True.
|
(or, by default, the observation space of the first sub-environment).
|
||||||
|
ValueError: If observation_space is a custom space (i.e. not a default space in Gym,
|
||||||
|
such as gym.spaces.Box, gym.spaces.Discrete, or gym.spaces.Dict) and shared_memory is True.
|
||||||
"""
|
"""
|
||||||
ctx = mp.get_context(context)
|
ctx = mp.get_context(context)
|
||||||
self.env_fns = env_fns
|
self.env_fns = env_fns
|
||||||
@@ -163,6 +174,9 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
seed: The seeds use with the environments
|
seed: The seeds use with the environments
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AlreadyPendingCallError: Calling `seed` while waiting for a pending call to complete
|
||||||
"""
|
"""
|
||||||
super().seed(seed=seed)
|
super().seed(seed=seed)
|
||||||
self._assert_is_running()
|
self._assert_is_running()
|
||||||
@@ -382,6 +396,10 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
name: Name of the method or property to call.
|
name: Name of the method or property to call.
|
||||||
*args: Arguments to apply to the method call.
|
*args: Arguments to apply to the method call.
|
||||||
**kwargs: Keyword arguments to apply to the method call.
|
**kwargs: Keyword arguments to apply to the method call.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||||
|
AlreadyPendingCallError: Calling `call_async` while waiting for a pending call to complete
|
||||||
"""
|
"""
|
||||||
self._assert_is_running()
|
self._assert_is_running()
|
||||||
if self._state != AsyncState.DEFAULT:
|
if self._state != AsyncState.DEFAULT:
|
||||||
@@ -399,10 +417,15 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
"""Calls all parent pipes and waits for the results.
|
"""Calls all parent pipes and waits for the results.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
timeout: Number of seconds before the call to `step_wait` times out. If `None` (default), the call to `step_wait` never times out.
|
timeout: Number of seconds before the call to `step_wait` times out.
|
||||||
|
If `None` (default), the call to `step_wait` never times out.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of the results of the individual calls to the method or property for each environment.
|
List of the results of the individual calls to the method or property for each environment.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NoAsyncCallError: Calling `call_wait` without any prior call to `call_async`.
|
||||||
|
TimeoutError: The call to `call_wait` has timed out after timeout second(s).
|
||||||
"""
|
"""
|
||||||
self._assert_is_running()
|
self._assert_is_running()
|
||||||
if self._state != AsyncState.WAITING_CALL:
|
if self._state != AsyncState.WAITING_CALL:
|
||||||
@@ -431,6 +454,10 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
values: Values of the property to be set to. If ``values`` is a list or
|
values: Values of the property to be set to. If ``values`` is a list or
|
||||||
tuple, then it corresponds to the values for each individual
|
tuple, then it corresponds to the values for each individual
|
||||||
environment, otherwise a single value is set for all environments.
|
environment, otherwise a single value is set for all environments.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Values must be a list or tuple with length equal to the number of environments.
|
||||||
|
AlreadyPendingCallError: Calling `set_attr` while waiting for a pending call to complete.
|
||||||
"""
|
"""
|
||||||
self._assert_is_running()
|
self._assert_is_running()
|
||||||
if not isinstance(values, (list, tuple)):
|
if not isinstance(values, (list, tuple)):
|
||||||
|
@@ -39,12 +39,15 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
env_fns: iterable of callable functions that create the environments.
|
env_fns: iterable of callable functions that create the environments.
|
||||||
observation_space: Observation space of a single environment. If ``None``, then the observation space of the first environment is taken.
|
observation_space: Observation space of a single environment. If ``None``,
|
||||||
action_space: Action space of a single environment. If ``None``, then the action space of the first environment is taken.
|
then the observation space of the first environment is taken.
|
||||||
|
action_space: Action space of a single environment. If ``None``,
|
||||||
|
then the action space of the first environment is taken.
|
||||||
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
|
copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If the observation space of some sub-environment does not match observation_space (or, by default, the observation space of the first sub-environment).
|
RuntimeError: If the observation space of some sub-environment does not match observation_space
|
||||||
|
(or, by default, the observation space of the first sub-environment).
|
||||||
"""
|
"""
|
||||||
self.env_fns = env_fns
|
self.env_fns = env_fns
|
||||||
self.envs = [env_fn() for env_fn in env_fns]
|
self.envs = [env_fn() for env_fn in env_fns]
|
||||||
@@ -195,6 +198,9 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
values: Values of the property to be set to. If ``values`` is a list or
|
values: Values of the property to be set to. If ``values`` is a list or
|
||||||
tuple, then it corresponds to the values for each individual
|
tuple, then it corresponds to the values for each individual
|
||||||
environment, otherwise, a single value is set for all environments.
|
environment, otherwise, a single value is set for all environments.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Values must be a list or tuple with length equal to the number of environments.
|
||||||
"""
|
"""
|
||||||
if not isinstance(values, (list, tuple)):
|
if not isinstance(values, (list, tuple)):
|
||||||
values = [values for _ in range(self.num_envs)]
|
values = [values for _ in range(self.num_envs)]
|
||||||
|
@@ -39,6 +39,9 @@ def clear_mpi_env_vars():
|
|||||||
|
|
||||||
This context manager is a hacky way to clear those environment variables
|
This context manager is a hacky way to clear those environment variables
|
||||||
temporarily such as when we are starting multiprocessing Processes.
|
temporarily such as when we are starting multiprocessing Processes.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Yields for the context manager
|
||||||
"""
|
"""
|
||||||
removed_environment = {}
|
removed_environment = {}
|
||||||
for k, v in list(os.environ.items()):
|
for k, v in list(os.environ.items()):
|
||||||
|
@@ -33,6 +33,9 @@ def concatenate(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The output object. This object is a (possibly nested) numpy array.
|
The output object. This object is a (possibly nested) numpy array.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Space is not a valid :class:`gym.Space` instance
|
||||||
"""
|
"""
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
|
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
|
||||||
@@ -95,6 +98,9 @@ def create_empty_array(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The output object. This object is a (possibly nested) numpy array.
|
The output object. This object is a (possibly nested) numpy array.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Space is not a valid :class:`gym.Space` instance
|
||||||
"""
|
"""
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
|
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
|
||||||
|
@@ -28,6 +28,9 @@ def create_shared_memory(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
shared_memory for the shared object across processes.
|
shared_memory for the shared object across processes.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
CustomSpaceError: Space is not a valid :class:`gym.Space` instance
|
||||||
"""
|
"""
|
||||||
raise CustomSpaceError(
|
raise CustomSpaceError(
|
||||||
"Cannot create a shared memory for space with "
|
"Cannot create a shared memory for space with "
|
||||||
@@ -86,6 +89,8 @@ def read_from_shared_memory(
|
|||||||
Returns:
|
Returns:
|
||||||
Batch of observations as a (possibly nested) numpy array.
|
Batch of observations as a (possibly nested) numpy array.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
CustomSpaceError: Space is not a valid :class:`gym.Space` instance
|
||||||
"""
|
"""
|
||||||
raise CustomSpaceError(
|
raise CustomSpaceError(
|
||||||
"Cannot read from a shared memory for space with "
|
"Cannot read from a shared memory for space with "
|
||||||
@@ -137,7 +142,11 @@ def write_to_shared_memory(
|
|||||||
space: Observation space of a single environment in the vectorized environment.
|
space: Observation space of a single environment in the vectorized environment.
|
||||||
index: Index of the environment (must be in `[0, num_envs)`).
|
index: Index of the environment (must be in `[0, num_envs)`).
|
||||||
value: Observation of the single environment to write to shared memory.
|
value: Observation of the single environment to write to shared memory.
|
||||||
shared_memory: Shared object across processes. This contains the observations from the vectorized environment. This object is created with `create_shared_memory`.
|
shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
|
||||||
|
This object is created with `create_shared_memory`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
CustomSpaceError: Space is not a valid :class:`gym.Space` instance
|
||||||
"""
|
"""
|
||||||
raise CustomSpaceError(
|
raise CustomSpaceError(
|
||||||
"Cannot write to a shared memory for space with "
|
"Cannot write to a shared memory for space with "
|
||||||
|
@@ -33,6 +33,9 @@ def batch_space(space: Space, n: int = 1) -> Space:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Space (e.g. the observation space) for a batch of environments in the vectorized environment.
|
Space (e.g. the observation space) for a batch of environments in the vectorized environment.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Cannot batch space that is not a valid :class:`gym.Space` instance
|
||||||
"""
|
"""
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gym.Space` instance."
|
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gym.Space` instance."
|
||||||
@@ -147,9 +150,12 @@ def iterate(space: Space, items) -> Iterator:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Iterator over the elements in `items`.
|
Iterator over the elements in `items`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Space is not an instance of :class:`gym.Space`
|
||||||
"""
|
"""
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Space of type `{type(space)}` is not a valid `gym.Space` " "instance."
|
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -56,7 +56,13 @@ class VectorEnv(gym.Env):
|
|||||||
):
|
):
|
||||||
"""Reset the sub-environments asynchronously.
|
"""Reset the sub-environments asynchronously.
|
||||||
|
|
||||||
This method will return ``None``. A call to :meth:`reset_async` should be followed by a call to :meth:`reset_wait` to retrieve the results.
|
This method will return ``None``. A call to :meth:`reset_async` should be followed
|
||||||
|
by a call to :meth:`reset_wait` to retrieve the results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The reset seed
|
||||||
|
return_info: If to return info
|
||||||
|
options: Reset options
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -69,8 +75,18 @@ class VectorEnv(gym.Env):
|
|||||||
"""Retrieves the results of a :meth:`reset_async` call.
|
"""Retrieves the results of a :meth:`reset_async` call.
|
||||||
|
|
||||||
A call to this method must always be preceded by a call to :meth:`reset_async`.
|
A call to this method must always be preceded by a call to :meth:`reset_async`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed: The reset seed
|
||||||
|
return_info: If to return info
|
||||||
|
options: Reset options
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The results from :meth:`reset_async`
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotImplementedError: VectorEnv does not implement function
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def reset(
|
def reset(
|
||||||
self,
|
self,
|
||||||
@@ -96,15 +112,22 @@ class VectorEnv(gym.Env):
|
|||||||
"""Asynchronously performs steps in the sub-environments.
|
"""Asynchronously performs steps in the sub-environments.
|
||||||
|
|
||||||
The results can be retrieved via a call to :meth:`step_wait`.
|
The results can be retrieved via a call to :meth:`step_wait`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actions: The actions to take asynchronously
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def step_wait(self, **kwargs):
|
def step_wait(self, **kwargs):
|
||||||
"""Retrieves the results of a :meth:`step_async` call.
|
"""Retrieves the results of a :meth:`step_async` call.
|
||||||
|
|
||||||
A call to this method must always be preceded by a call to :meth:`step_async`.
|
A call to this method must always be preceded by a call to :meth:`step_async`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Additional keywords for vector implementation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The results from the :meth:`step_async` call
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
"""Take an action for each parallel environment.
|
"""Take an action for each parallel environment.
|
||||||
@@ -120,11 +143,9 @@ class VectorEnv(gym.Env):
|
|||||||
|
|
||||||
def call_async(self, name, *args, **kwargs):
|
def call_async(self, name, *args, **kwargs):
|
||||||
"""Calls a method name for each parallel environment asynchronously."""
|
"""Calls a method name for each parallel environment asynchronously."""
|
||||||
pass
|
|
||||||
|
|
||||||
def call_wait(self, **kwargs):
|
def call_wait(self, **kwargs) -> list[Any]:
|
||||||
"""After calling a method in :meth:`call_async`, this function collects the results."""
|
"""After calling a method in :meth:`call_async`, this function collects the results."""
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def call(self, name: str, *args, **kwargs) -> list[Any]:
|
def call(self, name: str, *args, **kwargs) -> list[Any]:
|
||||||
"""Call a method, or get a property, from each parallel environment.
|
"""Call a method, or get a property, from each parallel environment.
|
||||||
@@ -160,7 +181,6 @@ class VectorEnv(gym.Env):
|
|||||||
tuple, then it corresponds to the values for each individual environment, otherwise a single value
|
tuple, then it corresponds to the values for each individual environment, otherwise a single value
|
||||||
is set for all environments.
|
is set for all environments.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def close_extras(self, **kwargs):
|
def close_extras(self, **kwargs):
|
||||||
"""Clean up the extra resources e.g. beyond what's in this base class."""
|
"""Clean up the extra resources e.g. beyond what's in this base class."""
|
||||||
@@ -180,6 +200,8 @@ class VectorEnv(gym.Env):
|
|||||||
Notes:
|
Notes:
|
||||||
This will be automatically called when garbage collected or program exited.
|
This will be automatically called when garbage collected or program exited.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Keyword arguments passed to :meth:`close_extras`
|
||||||
"""
|
"""
|
||||||
if self.closed:
|
if self.closed:
|
||||||
return
|
return
|
||||||
@@ -260,8 +282,12 @@ class VectorEnv(gym.Env):
|
|||||||
if not getattr(self, "closed", True):
|
if not getattr(self, "closed", True):
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
"""Returns a string representation of the vector environment using the class name, number of environments and environment spec id."""
|
"""Returns a string representation of the vector environment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string containing the class name, number of environments and environment spec id
|
||||||
|
"""
|
||||||
if self.spec is None:
|
if self.spec is None:
|
||||||
return f"{self.__class__.__name__}({self.num_envs})"
|
return f"{self.__class__.__name__}({self.num_envs})"
|
||||||
else:
|
else:
|
||||||
|
@@ -54,6 +54,10 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
grayscale observations to make them 3-dimensional.
|
grayscale observations to make them 3-dimensional.
|
||||||
scale_obs (bool): if True, then observation normalized in range [0,1) is returned. It also limits memory
|
scale_obs (bool): if True, then observation normalized in range [0,1) is returned. It also limits memory
|
||||||
optimization benefits of FrameStack Wrapper.
|
optimization benefits of FrameStack Wrapper.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DependencyNotInstalled: opencv-python package not installed
|
||||||
|
ValueError: Disable frame-skipping in the original env
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
if cv2 is None:
|
if cv2 is None:
|
||||||
|
@@ -26,6 +26,9 @@ class LazyFrames:
|
|||||||
Args:
|
Args:
|
||||||
frames (list): The frames to convert to lazy frames
|
frames (list): The frames to convert to lazy frames
|
||||||
lz4_compress (bool): Use lz4 to compress the frames internally
|
lz4_compress (bool): Use lz4 to compress the frames internally
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DependencyNotInstalled: lz4 is not installed
|
||||||
"""
|
"""
|
||||||
self.frame_shape = tuple(frames[0].shape)
|
self.frame_shape = tuple(frames[0].shape)
|
||||||
self.shape = (len(frames),) + self.frame_shape
|
self.shape = (len(frames),) + self.frame_shape
|
||||||
|
@@ -46,6 +46,10 @@ class VideoRecorder:
|
|||||||
metadata (Optional[dict]): Contents to save to the metadata file.
|
metadata (Optional[dict]): Contents to save to the metadata file.
|
||||||
enabled (bool): Whether to actually record video, or just no-op (for convenience)
|
enabled (bool): Whether to actually record video, or just no-op (for convenience)
|
||||||
base_path (Optional[str]): Alternatively, path to the video file without extension, which will be added.
|
base_path (Optional[str]): Alternatively, path to the video file without extension, which will be added.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Error: You can pass at most one of `path` or `base_path`
|
||||||
|
Error: Invalid path given that must have a particular file extension
|
||||||
"""
|
"""
|
||||||
modes = env.metadata.get("render_modes", [])
|
modes = env.metadata.get("render_modes", [])
|
||||||
|
|
||||||
@@ -281,6 +285,9 @@ class TextEncoder:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
frame: A string or StringIO frame
|
frame: A string or StringIO frame
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidFrame: Wrong type for a frame, expects text frame to be a string or StringIO
|
||||||
"""
|
"""
|
||||||
if isinstance(frame, str):
|
if isinstance(frame, str):
|
||||||
string = frame
|
string = frame
|
||||||
@@ -366,6 +373,10 @@ class ImageEncoder:
|
|||||||
frame_shape: The expected frame shape, a tuple of height, weight and channels (3 or 4)
|
frame_shape: The expected frame shape, a tuple of height, weight and channels (3 or 4)
|
||||||
frames_per_sec: The number of frames per second the environment runs at
|
frames_per_sec: The number of frames per second the environment runs at
|
||||||
output_frames_per_sec: The output number of frames per second for the video
|
output_frames_per_sec: The output number of frames per second for the video
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidFrame: Expects frame to have shape (w,h,3) or (w,h,4)
|
||||||
|
DependencyNotInstalled: Found neither the ffmpeg nor avconv executables.
|
||||||
"""
|
"""
|
||||||
self.proc = None
|
self.proc = None
|
||||||
self.output_path = output_path
|
self.output_path = output_path
|
||||||
|
@@ -77,6 +77,7 @@ class PixelObservationWrapper(gym.ObservationWrapper):
|
|||||||
arrays.
|
arrays.
|
||||||
ValueError: If ``env``'s observation already contains any of the
|
ValueError: If ``env``'s observation already contains any of the
|
||||||
specified ``pixel_keys``.
|
specified ``pixel_keys``.
|
||||||
|
TypeError: When an unexpected pixel type is used
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
|
||||||
|
@@ -11,6 +11,12 @@ def capped_cubic_video_schedule(episode_id: int) -> bool:
|
|||||||
"""The default episode trigger.
|
"""The default episode trigger.
|
||||||
|
|
||||||
This function will trigger recordings at the episode indices 0, 1, 4, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ...
|
This function will trigger recordings at the episode indices 0, 1, 4, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ...
|
||||||
|
|
||||||
|
Args:
|
||||||
|
episode_id: The episode number
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If to apply a video schedule number
|
||||||
"""
|
"""
|
||||||
if episode_id < 1000:
|
if episode_id < 1000:
|
||||||
return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
|
return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
|
||||||
|
@@ -50,6 +50,9 @@ class ResizeObservation(gym.ObservationWrapper):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The reshaped observations
|
The reshaped observations
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DependencyNotInstalled: opencv-python is not installed
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
|
Reference in New Issue
Block a user