mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 06:07:08 +00:00
Formatted doctest and added more consistency (#281)
This commit is contained in:
@@ -41,7 +41,7 @@ The Git hooks can also be run manually with `pre-commit run --all-files`, and if
|
||||
**Note:** you may have to run `pre-commit run --all-files` manually a couple of times to make it pass when you commit, as each formatting tool will first format the code and fail the first time but should pass the second time.
|
||||
|
||||
Additionally, for pull requests, the project runs a number of tests for the whole project using [pytest](https://docs.pytest.org/en/latest/getting-started.html#install-pytest).
|
||||
These tests can be run locally with `pytest` in the root folder.
|
||||
These tests can be run locally with `pytest` in the root folder. If any doctest is modified, run `pytest --doctest-modules --doctest-continue-on-failure gymnasium` to check the changes.
|
||||
|
||||
## Docstrings
|
||||
|
||||
|
@@ -97,14 +97,15 @@ Wrappers are a convenient way to modify an existing environment without having t
|
||||
In order to wrap an environment, you must first initialize a base environment. Then you can pass this environment along with (possibly optional) parameters to the wrapper's constructor:
|
||||
|
||||
```python
|
||||
>>> import gymnasium
|
||||
>>> from gymnasium.wrappers import RescaleAction
|
||||
>>> base_env = gymnasium.make("BipedalWalker-v3")
|
||||
>>> base_env.action_space
|
||||
Box([-1. -1. -1. -1.], [1. 1. 1. 1.], (4,), float32)
|
||||
>>> wrapped_env = RescaleAction(base_env, min_action=0, max_action=1)
|
||||
>>> wrapped_env.action_space
|
||||
Box([0. 0. 0. 0.], [1. 1. 1. 1.], (4,), float32)
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.wrappers import FlattenObservation
|
||||
>>> env = gym.make("CarRacing-v2")
|
||||
>>> env.observation_space.shape
|
||||
(96, 96, 3)
|
||||
>>> wrapped_env = FlattenObservation(env)
|
||||
>>> wrapped_env.observation_space.shape
|
||||
(27648,)
|
||||
|
||||
```
|
||||
|
||||
Gymnasium already provides many commonly used wrappers for you. Some examples:
|
||||
@@ -120,9 +121,10 @@ If you have a wrapped environment, and you want to get the unwrapped environment
|
||||
|
||||
```python
|
||||
>>> wrapped_env
|
||||
<RescaleAction<TimeLimit<BipedalWalker<BipedalWalker-v3>>>>
|
||||
<FlattenObservation<TimeLimit<OrderEnforcing<PassiveEnvChecker<CarRacing<CarRacing-v2>>>>>>
|
||||
>>> wrapped_env.unwrapped
|
||||
<gymnasium.envs.box2d.bipedal_walker.BipedalWalker object at 0x7f87d70712d0>
|
||||
<gymnasium.envs.box2d.car_racing.CarRacing object at 0x7f04efcb8850>
|
||||
|
||||
```
|
||||
|
||||
## More information
|
||||
|
@@ -23,29 +23,23 @@ class CartPoleFunctional(
|
||||
):
|
||||
"""Cartpole but in jax and functional.
|
||||
|
||||
Example usage:
|
||||
|
||||
Example:
|
||||
>>> import jax
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
|
||||
|
||||
>>> key = jax.random.PRNGKey(0)
|
||||
|
||||
>>> env = CartPoleFunctional({"x_init": 0.5})
|
||||
>>> state = env.initial(key)
|
||||
>>> print(state)
|
||||
[ 0.46532142 -0.27484107 0.13302994 -0.20361817]
|
||||
>>> print(env.transition(state, 0))
|
||||
[ 0.4598246 -0.6357784 0.12895757 0.1278053 ]
|
||||
|
||||
>>> env.transform(jax.jit)
|
||||
|
||||
>>> state = env.initial(key)
|
||||
>>> print(state)
|
||||
[ 0.46532142 -0.27484107 0.13302994 -0.20361817]
|
||||
>>> print(env.transition(state, 0))
|
||||
[ 0.4598246 -0.6357784 0.12895757 0.12780523]
|
||||
|
||||
>>> vkey = jax.random.split(key, 10)
|
||||
>>> env.transform(jax.vmap)
|
||||
>>> vstate = env.initial(vkey)
|
||||
|
@@ -125,13 +125,18 @@ class OrderEnforcingV0(gym.Wrapper):
|
||||
>>> from gymnasium.experimental.wrappers import OrderEnforcingV0
|
||||
>>> env = gym.make("CartPole-v1", render_mode="human")
|
||||
>>> env = OrderEnforcingV0(env)
|
||||
>>> env.step(0) # doctest: +SKIP
|
||||
>>> env.step(0)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset()
|
||||
>>> env.render() # doctest: +SKIP
|
||||
gymnasium.error.ResetNeeded('Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.')
|
||||
>>> env.render()
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
gymnasium.error.ResetNeeded: Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.
|
||||
>>> _ = env.reset()
|
||||
>>> env.render()
|
||||
>>> _ = env.step(0)
|
||||
>>> env.close()
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):
|
||||
|
@@ -44,8 +44,8 @@ class LambdaObservationV0(gym.ObservationWrapper):
|
||||
>>> np.random.seed(0)
|
||||
>>> env = gym.make("CartPole-v1")
|
||||
>>> env = LambdaObservationV0(env, lambda obs: obs + 0.1 * np.random.random(obs.shape), env.observation_space)
|
||||
>>> env.reset(seed=42) # doctest: +SKIP
|
||||
(array([ 0.06199517, 0.0511615 , -0.04432538, 0.02694618]), {})
|
||||
>>> env.reset(seed=42)
|
||||
(array([0.08227695, 0.06540678, 0.09613613, 0.07422512]), {})
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -349,7 +349,6 @@ class RescaleObservationV0(LambdaObservationV0):
|
||||
>>> env = RescaleObservationV0(env, np.array([-2, -1, -10]), np.array([1, 0, 1]))
|
||||
>>> env.observation_space
|
||||
Box([ -2. -1. -10.], [1. 0. 1.], (3,), float32)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@@ -57,7 +57,7 @@ class LambdaRewardV0(gym.RewardWrapper):
|
||||
class ClipRewardV0(LambdaRewardV0):
|
||||
"""A wrapper that clips the rewards for an environment between an upper and lower bound.
|
||||
|
||||
Example with an upper and lower bound:
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.experimental.wrappers import ClipRewardV0
|
||||
>>> env = gym.make("CartPole-v1")
|
||||
|
@@ -294,23 +294,21 @@ class HumanRenderingV0(gym.Wrapper):
|
||||
>>> wrapped = HumanRenderingV0(env)
|
||||
>>> obs, _ = wrapped.reset() # This will start rendering to the screen
|
||||
|
||||
The wrapper can also be applied directly when the environment is instantiated, simply by passing
|
||||
``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not
|
||||
implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``).
|
||||
The wrapper can also be applied directly when the environment is instantiated, simply by passing
|
||||
``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not
|
||||
implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``).
|
||||
|
||||
Example:
|
||||
>>> env = gym.make("CartPoleJax-v1", render_mode="human") # CartPoleJax-v1 doesn't implement human-rendering natively
|
||||
>>> obs, _ = env.reset() # This will start rendering to the screen
|
||||
|
||||
Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method
|
||||
Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method
|
||||
will always return an empty list:
|
||||
|
||||
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list")
|
||||
>>> wrapped = HumanRenderingV0(env)
|
||||
>>> obs, _ = wrapped.reset()
|
||||
>>> env.render() # env.render() will always return an empty list!
|
||||
[]
|
||||
|
||||
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list")
|
||||
>>> wrapped = HumanRenderingV0(env)
|
||||
>>> obs, _ = wrapped.reset()
|
||||
>>> env.render() # env.render() will always return an empty list!
|
||||
[]
|
||||
"""
|
||||
|
||||
def __init__(self, env):
|
||||
|
@@ -89,7 +89,7 @@ class TimeAwareObservationV0(gym.ObservationWrapper):
|
||||
>>> env.step(env.action_space.sample())[0]
|
||||
{'obs': array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476], dtype=float32), 'time': 0.002}
|
||||
|
||||
Flatten observation space example:
|
||||
Flatten observation space example:
|
||||
>>> env = gym.make("CartPole-v1")
|
||||
>>> env = TimeAwareObservationV0(env, flatten=True)
|
||||
>>> env.observation_space
|
||||
@@ -100,7 +100,6 @@ class TimeAwareObservationV0(gym.ObservationWrapper):
|
||||
>>> env.step(env.action_space.sample())[0]
|
||||
array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476, 0.002 ],
|
||||
dtype=float32)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@@ -16,14 +16,13 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
|
||||
|
||||
Elements of this space are (ordered) dictionaries of elements from the constituent spaces.
|
||||
|
||||
Example usage:
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Dict, Discrete
|
||||
>>> observation_space = Dict({"position": Discrete(2), "velocity": Discrete(3)}, seed=42)
|
||||
>>> observation_space.sample()
|
||||
OrderedDict([('position', 0), ('velocity', 2)])
|
||||
|
||||
Example usage [nested]::
|
||||
With a nested dict:
|
||||
|
||||
>>> from gymnasium.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete
|
||||
>>> Dict( # doctest: +SKIP
|
||||
@@ -62,18 +61,19 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
|
||||
of spaces to :meth:`__init__` via the ``spaces`` argument, or you pass the spaces as separate
|
||||
keyword arguments (where you will need to avoid the keys ``spaces`` and ``seed``)
|
||||
|
||||
Example::
|
||||
|
||||
>>> from gymnasium.spaces import Box, Discrete
|
||||
>>> Dict({"position": Box(-1, 1, shape=(2,)), "color": Discrete(3)})
|
||||
Dict('color': Discrete(3), 'position': Box(-1.0, 1.0, (2,), float32))
|
||||
>>> Dict(position=Box(-1, 1, shape=(2,)), color=Discrete(3))
|
||||
Dict('position': Box(-1.0, 1.0, (2,), float32), 'color': Discrete(3))
|
||||
|
||||
Args:
|
||||
spaces: A dictionary of spaces. This specifies the structure of the :class:`Dict` space
|
||||
seed: Optionally, you can use this argument to seed the RNGs of the spaces that make up the :class:`Dict` space.
|
||||
**spaces_kwargs: If ``spaces`` is ``None``, you need to pass the constituent spaces as keyword arguments, as described above.
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Dict, Box, Discrete
|
||||
>>> observation_space = Dict({"position": Box(-1, 1, shape=(2,)), "color": Discrete(3)}, seed=42)
|
||||
>>> observation_space.sample()
|
||||
OrderedDict([('color', 0), ('position', array([-0.3991573 , 0.21649833], dtype=float32))])
|
||||
>>> observation_space = Dict(position=Box(-1, 1, shape=(2,)), color=Discrete(3), seed=42)
|
||||
>>> observation_space.sample()
|
||||
OrderedDict([('position', array([0.6273108, 0.240238 ], dtype=float32)), ('color', 2)])
|
||||
"""
|
||||
# Convert the spaces into an OrderedDict
|
||||
if isinstance(spaces, collections.abc.Mapping) and not isinstance(
|
||||
|
@@ -13,12 +13,14 @@ class Discrete(Space[np.int64]):
|
||||
|
||||
This class represents a finite subset of integers, more specifically a set of the form :math:`\{ a, a+1, \dots, a+n-1 \}`.
|
||||
|
||||
Example::
|
||||
|
||||
>>> Discrete(2) # {0, 1}
|
||||
Discrete(2)
|
||||
>>> Discrete(3, start=-1) # {-1, 0, 1}
|
||||
Discrete(3, start=-1)
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Discrete
|
||||
>>> observation_space = Discrete(2, seed=42) # {0, 1}
|
||||
>>> observation_space.sample()
|
||||
0
|
||||
>>> observation_space = Discrete(3, start=-1, seed=42) # {-1, 0, 1}
|
||||
>>> observation_space.sample()
|
||||
-1
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@@ -29,9 +29,28 @@ class GraphInstance(NamedTuple):
|
||||
class Graph(Space[GraphInstance]):
|
||||
r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`.
|
||||
|
||||
Example usage::
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Graph, Box, Discrete
|
||||
>>> observation_space = Graph(node_space=Box(low=-100, high=100, shape=(3,)), edge_space=Discrete(3), seed=42)
|
||||
>>> observation_space.sample()
|
||||
GraphInstance(nodes=array([[-12.224312 , 71.71958 , 39.473606 ],
|
||||
[-81.16453 , 95.12447 , 52.22794 ],
|
||||
[ 57.21286 , -74.37727 , -9.922812 ],
|
||||
[-25.840395 , 85.353 , 28.773024 ],
|
||||
[ 64.55232 , -11.317161 , -54.552258 ],
|
||||
[ 10.916958 , -87.23655 , 65.52624 ],
|
||||
[ 26.33288 , 51.61755 , -29.094807 ],
|
||||
[ 94.1396 , 78.62422 , 55.6767 ],
|
||||
[-61.072258 , -6.6557994, -91.23925 ],
|
||||
[-69.142105 , 36.60979 , 48.95243 ]], dtype=float32), edges=array([2, 0, 1, 1, 0, 0, 1, 0]), edge_links=array([[7, 5],
|
||||
[6, 9],
|
||||
[4, 1],
|
||||
[8, 6],
|
||||
[7, 0],
|
||||
[3, 7],
|
||||
[8, 4],
|
||||
[8, 8]]))
|
||||
|
||||
self.observation_space = spaces.Graph(node_space=space.Box(low=-100, high=100, shape=(3,)), edge_space=spaces.Discrete(3))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@@ -14,8 +14,8 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
|
||||
|
||||
Elements of this space are binary arrays of a shape that is fixed during construction.
|
||||
|
||||
Example Usage::
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import MultiBinary
|
||||
>>> observation_space = MultiBinary(5, seed=42)
|
||||
>>> observation_space.sample()
|
||||
array([1, 0, 1, 0, 1], dtype=int8)
|
||||
|
@@ -30,12 +30,13 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
|
||||
Although this feature is rarely used, :class:`MultiDiscrete` spaces may also have several axes
|
||||
if ``nvec`` has several axes:
|
||||
|
||||
Example::
|
||||
|
||||
>> d = MultiDiscrete(np.array([[1, 2], [3, 4]]), seed=42)
|
||||
>> d.sample()
|
||||
Example:
|
||||
>>> from gymnasium.spaces import MultiDiscrete
|
||||
>>> import numpy as np
|
||||
>>> observation_space = MultiDiscrete(np.array([[1, 2], [3, 4]]), seed=42)
|
||||
>>> observation_space.sample()
|
||||
array([[0, 0],
|
||||
[2, 3]])
|
||||
[2, 2]])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@@ -17,13 +17,14 @@ class Sequence(Space[typing.Tuple[Any, ...]]):
|
||||
This space represents the set of tuples of the form :math:`(a_0, \dots, a_n)` where the :math:`a_i` belong
|
||||
to some space that is specified during initialization and the integer :math:`n` is not fixed
|
||||
|
||||
Example::
|
||||
>>> from gymnasium.spaces import Box
|
||||
>>> space = Sequence(Box(0, 1), seed=42)
|
||||
>>> space.sample() # doctest: +SKIP
|
||||
(array([0.6369617], dtype=float32),)
|
||||
>>> space.sample() # doctest: +SKIP
|
||||
(array([0.01652764], dtype=float32), array([0.8132702], dtype=float32),)
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Sequence, Box
|
||||
>>> observation_space = Sequence(Box(0, 1), seed=2)
|
||||
>>> observation_space.sample()
|
||||
(array([0.26161215], dtype=float32),)
|
||||
>>> observation_space = Sequence(Box(0, 1), seed=0)
|
||||
>>> observation_space.sample()
|
||||
(array([0.6369617], dtype=float32), array([0.26978672], dtype=float32), array([0.04097353], dtype=float32))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@@ -17,7 +17,8 @@ alphanumeric: frozenset[str] = frozenset(
|
||||
class Text(Space[str]):
|
||||
r"""A space representing a string comprised of characters from a given charset.
|
||||
|
||||
Example::
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Text
|
||||
>>> # {"", "B5", "hello", ...}
|
||||
>>> Text(5)
|
||||
Text(1, 5, characters=0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz)
|
||||
|
@@ -15,9 +15,8 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
|
||||
|
||||
Elements of this space are tuples of elements of the constituent spaces.
|
||||
|
||||
Example usage::
|
||||
|
||||
>>> from gymnasium.spaces import Box, Discrete
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Tuple, Box, Discrete
|
||||
>>> observation_space = Tuple((Discrete(2), Box(-1, 1, shape=(2,))), seed=42)
|
||||
>>> observation_space.sample()
|
||||
(0, array([-0.3991573 , 0.21649833], dtype=float32))
|
||||
|
@@ -33,13 +33,6 @@ from gymnasium.spaces import (
|
||||
def flatdim(space: Space[Any]) -> int:
|
||||
"""Return the number of dimensions a flattened equivalent of this space would have.
|
||||
|
||||
Example usage::
|
||||
|
||||
>>> from gymnasium.spaces import Discrete, Dict
|
||||
>>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
|
||||
>>> flatdim(space)
|
||||
5
|
||||
|
||||
Args:
|
||||
space: The space to return the number of dimensions of the flattened spaces
|
||||
|
||||
@@ -49,6 +42,12 @@ def flatdim(space: Space[Any]) -> int:
|
||||
Raises:
|
||||
NotImplementedError: if the space is not defined in :mod:`gymnasium.spaces`.
|
||||
ValueError: if the space cannot be flattened into a :class:`gymnasium.spaces.Box`
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Dict, Discrete
|
||||
>>> space = Dict({"position": Discrete(2), "velocity": Discrete(3)})
|
||||
>>> flatdim(space)
|
||||
5
|
||||
"""
|
||||
if space.is_np_flattenable is False:
|
||||
raise ValueError(
|
||||
@@ -117,19 +116,6 @@ def flatten(space: Space[T], x: T) -> FlatType:
|
||||
This is useful when e.g. points from spaces must be passed to a neural
|
||||
network, which only understands flat arrays of floats.
|
||||
|
||||
Example usage::
|
||||
>>> from gymnasium.spaces import Box, Discrete, Tuple
|
||||
>>> space = Box(0, 1, shape=(3, 5))
|
||||
>>> flatten(space, space.sample()).shape
|
||||
(15,)
|
||||
>>> space = Discrete(4)
|
||||
>>> flatten(space, 2)
|
||||
array([0, 0, 1, 0])
|
||||
>>> space = Tuple((Box(0, 1, shape=(2,)), Box(0, 1, shape=(3,)), Discrete(3)))
|
||||
>>> example = ((.5, .25), (1., 0., .2), 1)
|
||||
>>> flatten(space, example)
|
||||
array([0.5 , 0.25, 1. , 0. , 0.2 , 0. , 1. , 0. ])
|
||||
|
||||
Args:
|
||||
space: The space that ``x`` is flattened by
|
||||
x: The value to flatten
|
||||
@@ -151,6 +137,19 @@ def flatten(space: Space[T], x: T) -> FlatType:
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the space is not defined in :mod:`gymnasium.spaces`.
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Box, Discrete, Tuple
|
||||
>>> space = Box(0, 1, shape=(3, 5))
|
||||
>>> flatten(space, space.sample()).shape
|
||||
(15,)
|
||||
>>> space = Discrete(4)
|
||||
>>> flatten(space, 2)
|
||||
array([0, 0, 1, 0])
|
||||
>>> space = Tuple((Box(0, 1, shape=(2,)), Box(0, 1, shape=(3,)), Discrete(3)))
|
||||
>>> example = ((.5, .25), (1., 0., .2), 1)
|
||||
>>> flatten(space, example)
|
||||
array([0.5 , 0.25, 1. , 0. , 0.2 , 0. , 1. , 0. ])
|
||||
"""
|
||||
raise NotImplementedError(f"Unknown space: `{space}`")
|
||||
|
||||
@@ -388,7 +387,17 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
|
||||
a Box, and the results may not be integers or one-hot encodings. This may result in
|
||||
errors or non-uniform sampling.
|
||||
|
||||
Example::
|
||||
Args:
|
||||
space: The space to flatten
|
||||
|
||||
Returns:
|
||||
A flattened Box
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if the space is not defined in :mod:`gymnasium.spaces`.
|
||||
|
||||
Example:
|
||||
Flatten spaces.Box:
|
||||
>>> from gymnasium.spaces import Box
|
||||
>>> box = Box(0.0, 1.0, shape=(3, 4, 5))
|
||||
>>> box
|
||||
@@ -398,15 +407,15 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
|
||||
>>> flatten(box, box.sample()) in flatten_space(box)
|
||||
True
|
||||
|
||||
Example that flattens a discrete space::
|
||||
Flatten spaces.Discrete:
|
||||
>>> from gymnasium.spaces import Discrete
|
||||
>>> discrete = Discrete(5)
|
||||
>>> flatten_space(discrete)
|
||||
Box(0, 1, (5,), int64)
|
||||
>>> flatten(box, box.sample()) in flatten_space(box)
|
||||
>>> flatten(discrete, discrete.sample()) in flatten_space(discrete)
|
||||
True
|
||||
|
||||
Example that recursively flattens a dict::
|
||||
Flatten spaces.Dict:
|
||||
>>> from gymnasium.spaces import Dict, Discrete, Box
|
||||
>>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))})
|
||||
>>> flatten_space(space)
|
||||
@@ -414,23 +423,13 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
|
||||
>>> flatten(space, space.sample()) in flatten_space(space)
|
||||
True
|
||||
|
||||
|
||||
Example that flattens a graph::
|
||||
|
||||
Flatten spaces.Graph:
|
||||
>>> from gymnasium.spaces import Graph, Discrete, Box
|
||||
>>> space = Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5))
|
||||
>>> flatten_space(space)
|
||||
Graph(Box(-100.0, 100.0, (12,), float32), Box(0, 1, (5,), int64))
|
||||
>>> flatten(space, space.sample()) in flatten_space(space)
|
||||
True
|
||||
|
||||
Args:
|
||||
space: The space to flatten
|
||||
|
||||
Returns:
|
||||
A flattened Box
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if the space is not defined in :mod:`gymnasium.spaces`.
|
||||
"""
|
||||
raise NotImplementedError(f"Unknown space: `{space}`")
|
||||
|
||||
|
@@ -4,7 +4,7 @@
|
||||
class EzPickle:
|
||||
"""Objects that are pickled and unpickled via their constructor arguments.
|
||||
|
||||
Example::
|
||||
Example:
|
||||
>>> class Dog(Animal, EzPickle): # doctest: +SKIP
|
||||
... def __init__(self, furcolor, tailkind="bushy"):
|
||||
... Animal.__init__()
|
||||
|
@@ -157,36 +157,6 @@ def play(
|
||||
):
|
||||
"""Allows one to play the game using keyboard.
|
||||
|
||||
Example::
|
||||
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.utils.play import play
|
||||
>>> play(gym.make("CarRacing-v2", render_mode="rgb_array"), keys_to_action={ # doctest: +SKIP
|
||||
... "w": np.array([0, 0.7, 0]),
|
||||
... "a": np.array([-1, 0, 0]),
|
||||
... "s": np.array([0, 0, 1]),
|
||||
... "d": np.array([1, 0, 0]),
|
||||
... "wa": np.array([-1, 0.7, 0]),
|
||||
... "dw": np.array([1, 0.7, 0]),
|
||||
... "ds": np.array([1, 0, 1]),
|
||||
... "as": np.array([-1, 0, 1]),
|
||||
... }, noop=np.array([0,0,0]))
|
||||
|
||||
Above code works also if the environment is wrapped, so it's particularly useful in
|
||||
verifying that the frame-level preprocessing does not render the game
|
||||
unplayable.
|
||||
|
||||
If you wish to plot real time statistics as you play, you can use
|
||||
:class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward
|
||||
for last 150 steps.
|
||||
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.utils.play import PlayPlot, play
|
||||
>>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
|
||||
... return [rew,]
|
||||
>>> plotter = PlayPlot(callback, 150, ["reward"]) # doctest: +SKIP
|
||||
>>> play(gym.make("CartPole-v1"), callback=plotter.callback) # doctest: +SKIP
|
||||
|
||||
Args:
|
||||
env: Environment to use for playing.
|
||||
transpose: If this is ``True``, the output of observation is transposed. Defaults to ``True``.
|
||||
@@ -233,6 +203,35 @@ def play(
|
||||
If ``None``, default ``key_to_action`` mapping for that environment is used, if provided.
|
||||
seed: Random seed used when resetting the environment. If None, no seed is used.
|
||||
noop: The action used when no key input has been entered, or the entered key combination is unknown.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.utils.play import play
|
||||
>>> play(gym.make("CarRacing-v2", render_mode="rgb_array"), keys_to_action={ # doctest: +SKIP
|
||||
... "w": np.array([0, 0.7, 0]),
|
||||
... "a": np.array([-1, 0, 0]),
|
||||
... "s": np.array([0, 0, 1]),
|
||||
... "d": np.array([1, 0, 0]),
|
||||
... "wa": np.array([-1, 0.7, 0]),
|
||||
... "dw": np.array([1, 0.7, 0]),
|
||||
... "ds": np.array([1, 0, 1]),
|
||||
... "as": np.array([-1, 0, 1]),
|
||||
... }, noop=np.array([0,0,0]))
|
||||
|
||||
The above code also works if the environment is wrapped, so it's particularly useful in
|
||||
verifying that the frame-level preprocessing does not render the game
|
||||
unplayable.
|
||||
|
||||
If you wish to plot real time statistics as you play, you can use
|
||||
:class:`gymnasium.utils.play.PlayPlot`. Here's sample code for plotting the reward
|
||||
for the last 150 steps.
|
||||
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.utils.play import PlayPlot, play
|
||||
>>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
|
||||
... return [rew,]
|
||||
>>> plotter = PlayPlot(callback, 150, ["reward"]) # doctest: +SKIP
|
||||
>>> play(gym.make("CartPole-v1"), callback=plotter.callback) # doctest: +SKIP
|
||||
"""
|
||||
env.reset(seed=seed)
|
||||
|
||||
|
@@ -148,7 +148,7 @@ def step_api_compatibility(
|
||||
Returns:
|
||||
step_returns (tuple): Depending on `output_truncation_bool` bool, it can return `(obs, rew, done, info)` or `(obs, rew, terminated, truncated, info)`
|
||||
|
||||
Examples:
|
||||
Example:
|
||||
This function can be used to ensure compatibility in step interfaces with conflicting API. E.g., if env is written in old API,
|
||||
wrapper is written in new API, and the final step output is desired to be in old API.
|
||||
|
||||
@@ -162,8 +162,6 @@ def step_api_compatibility(
|
||||
>>> _ = vec_env.reset()
|
||||
>>> obs, rewards, dones, infos = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=False)
|
||||
>>> obs, rewards, terminated, truncated, info = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=True)
|
||||
|
||||
|
||||
"""
|
||||
if output_truncation_bool:
|
||||
return convert_to_terminated_truncated_step_api(step_returns, is_vector_env)
|
||||
|
@@ -21,16 +21,6 @@ def make(
|
||||
) -> VectorEnv:
|
||||
"""Create a vectorized environment from multiple copies of an environment, from its id.
|
||||
|
||||
Example::
|
||||
|
||||
>>> import gymnasium as gym
|
||||
>>> env = gym.vector.make('CartPole-v1', num_envs=3)
|
||||
>>> env.reset(seed=42)
|
||||
(array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ],
|
||||
[ 0.01522993, -0.04562247, -0.04799704, 0.03392126],
|
||||
[-0.03774345, -0.02418869, -0.00942293, 0.0469184 ]],
|
||||
dtype=float32), {})
|
||||
|
||||
Args:
|
||||
id: The environment ID. This must be a valid ID from the registry.
|
||||
num_envs: Number of copies of the environment.
|
||||
@@ -42,6 +32,15 @@ def make(
|
||||
|
||||
Returns:
|
||||
The vectorized environment.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> env = gym.vector.make('CartPole-v1', num_envs=3)
|
||||
>>> env.reset(seed=42)
|
||||
(array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ],
|
||||
[ 0.01522993, -0.04562247, -0.04799704, 0.03392126],
|
||||
[-0.03774345, -0.02418869, -0.00942293, 0.0469184 ]],
|
||||
dtype=float32), {})
|
||||
"""
|
||||
|
||||
def create_env(env_num: int) -> Callable[[], Env]:
|
||||
|
@@ -45,8 +45,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
It uses ``multiprocessing`` processes, and pipes for communication.
|
||||
|
||||
Example::
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> env = gym.vector.AsyncVectorEnv([
|
||||
... lambda: gym.make("Pendulum-v1", g=9.81),
|
||||
|
@@ -16,8 +16,7 @@ __all__ = ["SyncVectorEnv"]
|
||||
class SyncVectorEnv(VectorEnv):
|
||||
"""Vectorized environment that serially runs multiple environments.
|
||||
|
||||
Example::
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> env = gym.vector.SyncVectorEnv([
|
||||
... lambda: gym.make("Pendulum-v1", g=9.81),
|
||||
|
@@ -25,18 +25,6 @@ def concatenate(
|
||||
) -> Union[tuple, dict, np.ndarray]:
|
||||
"""Concatenate multiple samples from space into a single object.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from gymnasium.spaces import Box
|
||||
>>> import numpy as np
|
||||
>>> space = Box(low=0, high=1, shape=(3,), seed=42, dtype=np.float32)
|
||||
>>> out = np.zeros((2, 3), dtype=np.float32)
|
||||
>>> items = [space.sample() for _ in range(2)]
|
||||
>>> concatenate(space, items, out)
|
||||
array([[0.77395606, 0.43887845, 0.85859793],
|
||||
[0.697368 , 0.09417735, 0.97562236]], dtype=float32)
|
||||
|
||||
|
||||
Args:
|
||||
space: Observation space of a single environment in the vectorized environment.
|
||||
items: Samples to be concatenated.
|
||||
@@ -47,6 +35,16 @@ def concatenate(
|
||||
|
||||
Raises:
|
||||
ValueError: Space is not a valid :class:`gymnasium.Space` instance
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Box
|
||||
>>> import numpy as np
|
||||
>>> space = Box(low=0, high=1, shape=(3,), seed=42, dtype=np.float32)
|
||||
>>> out = np.zeros((2, 3), dtype=np.float32)
|
||||
>>> items = [space.sample() for _ in range(2)]
|
||||
>>> concatenate(space, items, out)
|
||||
array([[0.77395606, 0.43887845, 0.85859793],
|
||||
[0.697368 , 0.09417735, 0.97562236]], dtype=float32)
|
||||
"""
|
||||
raise ValueError(
|
||||
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
|
||||
@@ -90,20 +88,6 @@ def create_empty_array(
|
||||
) -> Union[tuple, dict, np.ndarray]:
|
||||
"""Create an empty (possibly nested) numpy array.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from gymnasium.spaces import Box, Dict
|
||||
>>> import numpy as np
|
||||
>>> space = Dict({
|
||||
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
||||
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
|
||||
>>> create_empty_array(space, n=2, fn=np.zeros)
|
||||
OrderedDict([('position', array([[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)), ('velocity', array([[0., 0.],
|
||||
[0., 0.]], dtype=float32))])
|
||||
|
||||
|
||||
|
||||
Args:
|
||||
space: Observation space of a single environment in the vectorized environment.
|
||||
n: Number of environments in the vectorized environment. If `None`, creates an empty sample from `space`.
|
||||
@@ -114,6 +98,17 @@ def create_empty_array(
|
||||
|
||||
Raises:
|
||||
ValueError: Space is not a valid :class:`gymnasium.Space` instance
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Box, Dict
|
||||
>>> import numpy as np
|
||||
>>> space = Dict({
|
||||
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
||||
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
|
||||
>>> create_empty_array(space, n=2, fn=np.zeros)
|
||||
OrderedDict([('position', array([[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)), ('velocity', array([[0., 0.],
|
||||
[0., 0.]], dtype=float32))])
|
||||
"""
|
||||
raise ValueError(
|
||||
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
|
||||
|
@@ -27,17 +27,6 @@ __all__ = ["BaseGymSpaces", "_BaseGymSpaces", "batch_space", "iterate"]
|
||||
def batch_space(space: Space, n: int = 1) -> Space:
|
||||
"""Create a (batched) space, containing multiple copies of a single space.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from gymnasium.spaces import Box, Dict
|
||||
>>> import numpy as np
|
||||
>>> space = Dict({
|
||||
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
||||
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)
|
||||
... })
|
||||
>>> batch_space(space, n=5)
|
||||
Dict('position': Box(0.0, 1.0, (5, 3), float32), 'velocity': Box(0.0, 1.0, (5, 2), float32))
|
||||
|
||||
Args:
|
||||
space: Space (e.g. the observation space) for a single environment in the vectorized environment.
|
||||
n: Number of environments in the vectorized environment.
|
||||
@@ -47,6 +36,16 @@ def batch_space(space: Space, n: int = 1) -> Space:
|
||||
|
||||
Raises:
|
||||
ValueError: Cannot batch space that is not a valid :class:`gym.Space` instance
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Box, Dict
|
||||
>>> import numpy as np
|
||||
>>> space = Dict({
|
||||
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
||||
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)
|
||||
... })
|
||||
>>> batch_space(space, n=5)
|
||||
Dict('position': Box(0.0, 1.0, (5, 3), float32), 'velocity': Box(0.0, 1.0, (5, 2), float32))
|
||||
"""
|
||||
raise ValueError(
|
||||
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gymnasium.Space` instance."
|
||||
@@ -138,8 +137,17 @@ def _batch_space_custom(space, n=1):
|
||||
def iterate(space: Space, items) -> Iterator:
|
||||
"""Iterate over the elements of a (batched) space.
|
||||
|
||||
Example::
|
||||
Args:
|
||||
space: Space to which `items` belong.
|
||||
items: Items to be iterated over.
|
||||
|
||||
Returns:
|
||||
Iterator over the elements in `items`.
|
||||
|
||||
Raises:
|
||||
ValueError: Space is not an instance of :class:`gym.Space`
|
||||
|
||||
Example:
|
||||
>>> from gymnasium.spaces import Box, Dict
|
||||
>>> import numpy as np
|
||||
>>> space = Dict({
|
||||
@@ -151,18 +159,10 @@ def iterate(space: Space, items) -> Iterator:
|
||||
OrderedDict([('position', array([0.77395606, 0.43887845, 0.85859793], dtype=float32)), ('velocity', array([0.77395606, 0.43887845], dtype=float32))])
|
||||
>>> next(it)
|
||||
OrderedDict([('position', array([0.697368 , 0.09417735, 0.97562236], dtype=float32)), ('velocity', array([0.85859793, 0.697368 ], dtype=float32))])
|
||||
>>> next(it) # doctest: +SKIP
|
||||
>>> next(it)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
StopIteration
|
||||
|
||||
Args:
|
||||
space: Space to which `items` belong.
|
||||
items: Items to be iterated over.
|
||||
|
||||
Returns:
|
||||
Iterator over the elements in `items`.
|
||||
|
||||
Raises:
|
||||
ValueError: Space is not an instance of :class:`gym.Space`
|
||||
"""
|
||||
raise ValueError(
|
||||
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
|
||||
|
@@ -125,8 +125,7 @@ class VectorEnv(gym.Env):
|
||||
Returns:
|
||||
A batch of observations and info from the vectorized environment.
|
||||
|
||||
An example::
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> envs = gym.vector.make("CartPole-v1", num_envs=3)
|
||||
>>> envs.reset(seed=42)
|
||||
@@ -134,7 +133,6 @@ class VectorEnv(gym.Env):
|
||||
[ 0.01522993, -0.04562247, -0.04799704, 0.03392126],
|
||||
[-0.03774345, -0.02418869, -0.00942293, 0.0469184 ]],
|
||||
dtype=float32), {})
|
||||
|
||||
"""
|
||||
self.reset_async(seed=seed, options=options)
|
||||
return self.reset_wait(seed=seed, options=options)
|
||||
@@ -174,8 +172,9 @@ class VectorEnv(gym.Env):
|
||||
the returned observation and info is not the final step's observation or info which is instead stored in
|
||||
info as `"final_observation"` and `"final_info"`.
|
||||
|
||||
An example::
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> import numpy as np
|
||||
>>> envs = gym.vector.make("CartPole-v1", num_envs=3)
|
||||
>>> _ = envs.reset(seed=42)
|
||||
>>> actions = np.array([1, 0, 1])
|
||||
|
@@ -24,23 +24,21 @@ class HumanRendering(gym.Wrapper):
|
||||
>>> wrapped = HumanRendering(env)
|
||||
>>> obs, _ = wrapped.reset() # This will start rendering to the screen
|
||||
|
||||
The wrapper can also be applied directly when the environment is instantiated, simply by passing
|
||||
``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not
|
||||
implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``).
|
||||
The wrapper can also be applied directly when the environment is instantiated, simply by passing
|
||||
``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not
|
||||
implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``).
|
||||
|
||||
Example:
|
||||
>>> env = gym.make("CartPoleJax-v1", render_mode="human") # CartPoleJax-v1 doesn't implement human-rendering natively
|
||||
>>> obs, _ = env.reset() # This will start rendering to the screen
|
||||
|
||||
Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method
|
||||
Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method
|
||||
will always return an empty list:
|
||||
|
||||
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list")
|
||||
>>> wrapped = HumanRendering(env)
|
||||
>>> obs, _ = wrapped.reset()
|
||||
>>> env.render() # env.render() will always return an empty list!
|
||||
[]
|
||||
|
||||
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list")
|
||||
>>> wrapped = HumanRendering(env)
|
||||
>>> obs, _ = wrapped.reset()
|
||||
>>> env.render() # env.render() will always return an empty list!
|
||||
[]
|
||||
"""
|
||||
|
||||
def __init__(self, env):
|
||||
|
@@ -11,13 +11,18 @@ class OrderEnforcing(gym.Wrapper):
|
||||
>>> from gymnasium.wrappers import OrderEnforcing
|
||||
>>> env = gym.make("CartPole-v1", render_mode="human")
|
||||
>>> env = OrderEnforcing(env)
|
||||
>>> env.step(0) # doctest: +SKIP
|
||||
>>> env.step(0)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset()
|
||||
>>> env.render() # doctest: +SKIP
|
||||
gymnasium.error.ResetNeeded('Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.')
|
||||
>>> env.render()
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
gymnasium.error.ResetNeeded: Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.
|
||||
>>> _ = env.reset()
|
||||
>>> env.render()
|
||||
>>> _ = env.step(0)
|
||||
>>> env.close()
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):
|
||||
|
@@ -15,7 +15,7 @@ class StepAPICompatibility(gym.Wrapper):
|
||||
env (gym.Env): the env to wrap. Can be in old or new API
|
||||
output_truncation_bool (bool): If True, convert the environment to use the new step API, which returns two booleans. (True by default)
|
||||
|
||||
Examples:
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> from gymnasium.wrappers import StepAPICompatibility
|
||||
>>> env = gym.make("CartPole-v1")
|
||||
@@ -24,7 +24,6 @@ class StepAPICompatibility(gym.Wrapper):
|
||||
>>> env = StepAPICompatibility(gym.make("CartPole-v1"))
|
||||
>>> env
|
||||
<StepAPICompatibility<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>>
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, output_truncation_bool: bool = True):
|
||||
|
@@ -17,7 +17,7 @@ class VectorListInfo(gym.Wrapper):
|
||||
|
||||
i.e. `VectorListInfo(RecordEpisodeStatistics(envs))`
|
||||
|
||||
Example::
|
||||
Example:
|
||||
>>> # As dict:
|
||||
>>> infos = {
|
||||
... "final_observation": "<array of length num-envs>",
|
||||
|
Reference in New Issue
Block a user