Add doctest check to CI and fix existing errors (#274)

This commit is contained in:
Valentin
2023-01-20 14:28:09 +01:00
committed by GitHub
parent 1551b89257
commit b4caf9df16
45 changed files with 321 additions and 222 deletions

View File

@@ -18,6 +18,8 @@ jobs:
--tag gymnasium-all-docker . --tag gymnasium-all-docker .
- name: Run tests - name: Run tests
run: docker run gymnasium-all-docker pytest tests/* run: docker run gymnasium-all-docker pytest tests/*
- name: Run doctest
run: docker run gymnasium-all-docker pytest --doctest-modules gymnasium/
build-necessary: build-necessary:
runs-on: runs-on:

View File

@@ -27,25 +27,50 @@ class CartPoleFunctional(
>>> import jax >>> import jax
>>> import jax.numpy as jnp >>> import jax.numpy as jnp
>>> from gymnasium.envs.phys2d.cartpole import CartPoleFunctional
>>> key = jax.random.PRNGKey(0) >>> key = jax.random.PRNGKey(0)
>>> env = CartPole({"x_init": 0.5}) >>> env = CartPoleFunctional({"x_init": 0.5})
>>> state = env.initial(key) >>> state = env.initial(key)
>>> print(state) >>> print(state)
>>> print(env.step(state, 0)) [ 0.46532142 -0.27484107 0.13302994 -0.20361817]
>>> print(env.transition(state, 0))
[ 0.4598246 -0.6357784 0.12895757 0.1278053 ]
>>> env.transform(jax.jit) >>> env.transform(jax.jit)
>>> state = env.initial(key) >>> state = env.initial(key)
>>> print(state) >>> print(state)
>>> print(env.step(state, 0)) [ 0.46532142 -0.27484107 0.13302994 -0.20361817]
>>> print(env.transition(state, 0))
[ 0.4598246 -0.6357784 0.12895757 0.12780523]
>>> vkey = jax.random.split(key, 10) >>> vkey = jax.random.split(key, 10)
>>> env.transform(jax.vmap) >>> env.transform(jax.vmap)
>>> vstate = env.initial(vkey) >>> vstate = env.initial(vkey)
>>> print(vstate) >>> print(vstate)
>>> print(env.step(vstate, jnp.array([0 for _ in range(10)]))) [[ 0.25117755 -0.03159595 0.09428263 0.12404168]
[ 0.231457 0.41420317 -0.13484478 0.29151905]
[-0.11706758 -0.37130308 0.13587534 0.33141208]
[-0.4613737 0.36557996 0.3950702 0.3639989 ]
[-0.14707637 -0.34273267 -0.32374108 -0.48110402]
[-0.45774353 0.3633288 -0.3157575 -0.03586268]
[ 0.37344885 -0.279778 -0.33894253 0.07415426]
[-0.20234215 0.39775252 -0.2556088 0.32877135]
[-0.2572986 -0.29943776 -0.45600426 -0.35740316]
[ 0.05436695 0.35021234 -0.36484408 0.2805779 ]]
>>> print(env.transition(vstate, jnp.array([0 for _ in range(10)])))
[[ 0.25054562 -0.38763174 0.09676346 0.4448946 ]
[ 0.23974106 0.09849604 -0.1290144 0.5390002 ]
[-0.12449364 -0.7323911 0.14250359 0.6634313 ]
[-0.45406207 -0.01028753 0.4023502 0.7505522 ]
[-0.15393102 -0.6168968 -0.33336315 -0.30407968]
[-0.45047694 0.08870795 -0.31647477 0.14311607]
[ 0.36785328 -0.54895645 -0.33745944 0.24393772]
[-0.19438711 0.10855066 -0.24903338 0.5316877 ]
[-0.26328734 -0.5420943 -0.46315232 -0.2344252 ]
[ 0.06137119 0.08665388 -0.35923252 0.4403924 ]]
""" """
gravity = 9.8 gravity = 9.8

View File

@@ -121,16 +121,17 @@ class OrderEnforcingV0(gym.Wrapper):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`. """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Example: Example:
>>> from gymnasium.envs.classic_control import CartPoleEnv >>> import gymnasium as gym
>>> env = CartPoleEnv() >>> from gymnasium.experimental.wrappers import OrderEnforcingV0
>>> env = gym.make("CartPole-v1", render_mode="human")
>>> env = OrderEnforcingV0(env) >>> env = OrderEnforcingV0(env)
>>> env.step(0) >>> env.step(0) # doctest: +SKIP
ResetNeeded: Cannot call env.step() before calling env.reset() gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset()
>>> env.render() # doctest: +SKIP
gymnasium.error.ResetNeeded('Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.')
>>> _ = env.reset()
>>> env.render() >>> env.render()
ResetNeeded: Cannot call env.render() before calling env.reset() >>> _ = env.step(0)
>>> env.reset()
>>> env.render()
>>> env.step(0)
""" """
def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False): def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):
@@ -185,7 +186,6 @@ class RecordEpisodeStatisticsV0(gym.Wrapper):
After the completion of an episode, ``info`` will look like this:: After the completion of an episode, ``info`` will look like this::
>>> info = { >>> info = {
... ...
... "episode": { ... "episode": {
... "r": "<cumulative reward>", ... "r": "<cumulative reward>",
... "l": "<episode length>", ... "l": "<episode length>",
@@ -196,7 +196,10 @@ class RecordEpisodeStatisticsV0(gym.Wrapper):
For vectorized environments the output will be in the form of:: For vectorized environments the output will be in the form of::
>>> infos = { >>> infos = {
... ... ... "final_observation": "<array of length num-envs>",
... "_final_observation": "<boolean array of length num-envs>",
... "final_info": "<array of length num-envs>",
... "_final_info": "<boolean array of length num-envs>",
... "episode": { ... "episode": {
... "r": "<array of cumulative reward>", ... "r": "<array of cumulative reward>",
... "l": "<array of episode length>", ... "l": "<array of episode length>",
@@ -205,6 +208,7 @@ class RecordEpisodeStatisticsV0(gym.Wrapper):
... "_episode": "<boolean array of length num-envs>" ... "_episode": "<boolean array of length num-envs>"
... } ... }
Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
:attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively. :attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.

View File

@@ -52,13 +52,15 @@ class ClipActionV0(LambdaActionV0):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import ClipActionV0
>>> import numpy as np >>> import numpy as np
>>> env = gym.make('BipedalWalker-v3', disable_env_checker=True) >>> env = gym.make("Hopper-v4", disable_env_checker=True)
>>> env = ClipActionV0(env) >>> env = ClipActionV0(env)
>>> env.action_space >>> env.action_space
Box(-1.0, 1.0, (4,), float32) Box(-inf, inf, (3,), float32)
>>> env.step(np.array([5.0, 2.0, -10.0, 0.0])) >>> _ = env.reset(seed=42)
# Executes the action np.array([1.0, 1.0, -1.0, 0]) in the base environment >>> _ = env.step(np.array([5.0, -2.0, 0.0]))
... # Executes the action np.array([1.0, -1.0, 0]) in the base environment
""" """
def __init__(self, env: gym.Env): def __init__(self, env: gym.Env):
@@ -89,13 +91,14 @@ class RescaleActionV0(LambdaActionV0):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import RescaleActionV0
>>> import numpy as np >>> import numpy as np
>>> env = gym.make('BipedalWalker-v3', disable_env_checker=True) >>> env = gym.make("Hopper-v4", disable_env_checker=True)
>>> _ = env.reset(seed=42) >>> _ = env.reset(seed=42)
>>> obs, _, _, _, _ = env.step(np.array([1,1,1,1])) >>> obs, _, _, _, _ = env.step(np.array([1,1,1]))
>>> _ = env.reset(seed=42) >>> _ = env.reset(seed=42)
>>> min_action = -0.5 >>> min_action = -0.5
>>> max_action = np.array([0.0, 0.5, 1.0, 0.75]) >>> max_action = np.array([0.0, 0.5, 0.75])
>>> wrapped_env = RescaleActionV0(env, min_action=min_action, max_action=max_action) >>> wrapped_env = RescaleActionV0(env, min_action=min_action, max_action=max_action)
>>> wrapped_env_obs, _, _, _, _ = wrapped_env.step(max_action) >>> wrapped_env_obs, _, _, _, _ = wrapped_env.step(max_action)
>>> np.alltrue(obs == wrapped_env_obs) >>> np.alltrue(obs == wrapped_env_obs)
@@ -122,7 +125,7 @@ class RescaleActionV0(LambdaActionV0):
if not isinstance(min_action, np.ndarray): if not isinstance(min_action, np.ndarray):
assert np.issubdtype(type(min_action), np.integer) or np.issubdtype( assert np.issubdtype(type(min_action), np.integer) or np.issubdtype(
type(max_action), np.floating type(min_action), np.floating
) )
min_action = np.full(env.action_space.shape, min_action) min_action = np.full(env.action_space.shape, min_action)

View File

@@ -39,11 +39,13 @@ class LambdaObservationV0(gym.ObservationWrapper):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import LambdaObservationV0
>>> import numpy as np >>> import numpy as np
>>> env = gym.make('CartPole-v1') >>> np.random.seed(0)
>>> env = LambdaObservationV0(env, lambda obs: obs + 0.1 * np.random.random(obs.shape)) >>> env = gym.make("CartPole-v1")
>>> env.reset() >>> env = LambdaObservationV0(env, lambda obs: obs + 0.1 * np.random.random(obs.shape), env.observation_space)
array([-0.08319338, 0.04635121, -0.07394746, 0.20877492]) >>> env.reset(seed=42) # doctest: +SKIP
(array([ 0.06199517, 0.0511615 , -0.04432538, 0.02694618]), {})
""" """
def __init__( def __init__(
@@ -75,17 +77,18 @@ class FilterObservationV0(LambdaObservationV0):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.wrappers.TransformObservation( >>> from gymnasium.wrappers import TransformObservation
... gym.make('CartPole-v1'), lambda obs: {'obs': obs, 'time': 0} >>> from gymnasium.experimental.wrappers import FilterObservationV0
... ) >>> env = gym.make("CartPole-v1")
>>> env = gym.wrappers.TransformObservation(env, lambda obs: {'obs': obs, 'time': 0})
>>> env.observation_space = gym.spaces.Dict(obs=env.observation_space, time=gym.spaces.Discrete(1)) >>> env.observation_space = gym.spaces.Dict(obs=env.observation_space, time=gym.spaces.Discrete(1))
>>> env.reset() >>> env.reset(seed=42)
{'obs': array([-0.00067088, -0.01860439, 0.04772898, -0.01911527], dtype=float32), 'time': 0} ({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': 0}, {})
>>> env = FilterObservationV0(env, filter_keys=['time']) >>> env = FilterObservationV0(env, filter_keys=['time'])
>>> env.reset() >>> env.reset(seed=42)
{'obs': array([ 0.04560107, 0.04466959, -0.0328232 , -0.02367178], dtype=float32)} ({'time': 0}, {})
>>> env.step(0) >>> env.step(0)
({'obs': array([ 0.04649447, -0.14996664, -0.03329664, 0.25847703], dtype=float32)}, 1.0, False, {}) ({'time': 0}, 1.0, False, False, {})
""" """
def __init__(self, env: gym.Env, filter_keys: Sequence[str | int]): def __init__(self, env: gym.Env, filter_keys: Sequence[str | int]):
@@ -171,13 +174,14 @@ class FlattenObservationV0(LambdaObservationV0):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.make('CarRacing-v1') >>> from gymnasium.experimental.wrappers import FlattenObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape >>> env.observation_space.shape
(96, 96, 3) (96, 96, 3)
>>> env = FlattenObservationV0(env) >>> env = FlattenObservationV0(env)
>>> env.observation_space.shape >>> env.observation_space.shape
(27648,) (27648,)
>>> obs, info = env.reset() >>> obs, _ = env.reset()
>>> obs.shape >>> obs.shape
(27648,) (27648,)
""" """
@@ -198,7 +202,8 @@ class GrayscaleObservationV0(LambdaObservationV0):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.make("CarRacing-v1") >>> from gymnasium.experimental.wrappers import GrayscaleObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape >>> env.observation_space.shape
(96, 96, 3) (96, 96, 3)
>>> grayscale_env = GrayscaleObservationV0(env) >>> grayscale_env = GrayscaleObservationV0(env)
@@ -258,6 +263,7 @@ class ResizeObservationV0(LambdaObservationV0):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import ResizeObservationV0
>>> env = gym.make("CarRacing-v2") >>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape >>> env.observation_space.shape
(96, 96, 3) (96, 96, 3)
@@ -303,7 +309,8 @@ class ReshapeObservationV0(LambdaObservationV0):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.make("CarRacing-v1") >>> from gymnasium.experimental.wrappers import ReshapeObservationV0
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape >>> env.observation_space.shape
(96, 96, 3) (96, 96, 3)
>>> reshape_env = ReshapeObservationV0(env, (24, 4, 96, 1, 3)) >>> reshape_env = ReshapeObservationV0(env, (24, 4, 96, 1, 3))
@@ -335,11 +342,14 @@ class RescaleObservationV0(LambdaObservationV0):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import RescaleObservationV0
>>> env = gym.make("Pendulum-v1") >>> env = gym.make("Pendulum-v1")
>>> env.observation_space >>> env.observation_space
Box([-1. -1. -8.], [1. 1. 8.], (3,), float32) Box([-1. -1. -8.], [1. 1. 8.], (3,), float32)
>>> env = RescaleObservationV0(env, np.array([-2, -1, -10]), np.array([1, 0, 1])) >>> env = RescaleObservationV0(env, np.array([-2, -1, -10]), np.array([1, 0, 1]))
Box([-2. -1. -10.], [1. 0. 1.], (3,), float32) >>> env.observation_space
Box([ -2. -1. -10.], [1. 0. 1.], (3,), float32)
""" """
def __init__( def __init__(

View File

@@ -62,7 +62,7 @@ class ClipRewardV0(LambdaRewardV0):
>>> from gymnasium.experimental.wrappers import ClipRewardV0 >>> from gymnasium.experimental.wrappers import ClipRewardV0
>>> env = gym.make("CartPole-v1") >>> env = gym.make("CartPole-v1")
>>> env = ClipRewardV0(env, 0, 0.5) >>> env = ClipRewardV0(env, 0, 0.5)
>>> env.reset() >>> _ = env.reset()
>>> _, rew, _, _, _ = env.step(1) >>> _, rew, _, _, _ = env.step(1)
>>> rew >>> rew
0.5 0.5

View File

@@ -288,26 +288,28 @@ class HumanRenderingV0(gym.Wrapper):
The ``render_mode`` of the wrapped environment must be either ``'rgb_array'`` or ``'rgb_array_list'``. The ``render_mode`` of the wrapped environment must be either ``'rgb_array'`` or ``'rgb_array_list'``.
Example: Example:
>>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import HumanRenderingV0
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array") >>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
>>> wrapped = HumanRenderingV0(env) >>> wrapped = HumanRenderingV0(env)
>>> wrapped.reset() # This will start rendering to the screen >>> obs, _ = wrapped.reset() # This will start rendering to the screen
The wrapper can also be applied directly when the environment is instantiated, simply by passing The wrapper can also be applied directly when the environment is instantiated, simply by passing
``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not ``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not
implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``). implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``).
Example: Example:
>>> env = gym.make("NoNativeRendering-v2", render_mode="human") # NoNativeRendering-v0 doesn't implement human-rendering natively >>> env = gym.make("CartPoleJax-v1", render_mode="human") # CartPoleJax-v1 doesn't implement human-rendering natively
>>> env.reset() # This will start rendering to the screen >>> obs, _ = env.reset() # This will start rendering to the screen
Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method
will always return an empty list: will always return an empty list:
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list") >>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list")
>>> wrapped = HumanRenderingV0(env) >>> wrapped = HumanRenderingV0(env)
>>> wrapped.reset() >>> obs, _ = wrapped.reset()
>>> env.render() >>> env.render() # env.render() will always return an empty list!
[] # env.render() will always return an empty list! []
""" """

View File

@@ -80,24 +80,27 @@ class TimeAwareObservationV0(gym.ObservationWrapper):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> from gymnasium.experimental.wrappers import TimeAwareObservationV0 >>> from gymnasium.experimental.wrappers import TimeAwareObservationV0
>>> env = gym.make('CartPole-v1') >>> env = gym.make("CartPole-v1")
>>> env = TimeAwareObservationV0(env) >>> env = TimeAwareObservationV0(env)
>>> env.observation_space >>> env.observation_space
Dict(obs: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), time: Box(0.0, 500, (1,), float32)) Dict('obs': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), 'time': Box(0.0, 1.0, (1,), float32))
>>> _ = env.reset() >>> _ = env.reset(seed=42)
>>> _ = env.action_space.seed(42)
>>> env.step(env.action_space.sample())[0] >>> env.step(env.action_space.sample())[0]
OrderedDict([('obs', {'obs': array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476], dtype=float32), 'time': 0.002}
... array([ 0.02866629, 0.2310988 , -0.02614601, -0.2600732 ], dtype=float32)),
... ('time', array([0.002]))])
Flatten observation space example: Flatten observation space example:
>>> env = gym.make('CartPole-v1') >>> env = gym.make("CartPole-v1")
>>> env = TimeAwareObservationV0(env, flatten=True) >>> env = TimeAwareObservationV0(env, flatten=True)
>>> env.observation_space >>> env.observation_space
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 500], (5,), float32) Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38
>>> _ = env.reset() 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 1.0000000e+00], (5,), float32)
>>> _ = env.reset(seed=42)
>>> _ = env.action_space.seed(42)
>>> env.step(env.action_space.sample())[0] >>> env.step(env.action_space.sample())[0]
array([-0.01232257, 0.19335455, -0.02244143, -0.32388705, 0.002 ], dtype=float32) array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476, 0.002 ],
dtype=float32)
""" """
def __init__( def __init__(
@@ -224,11 +227,12 @@ class FrameStackObservationV0(gym.Wrapper):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.make('CarRacing-v1') >>> from gymnasium.experimental.wrappers import FrameStackObservationV0
>>> env = FrameStack(env, 4) >>> env = gym.make("CarRacing-v2")
>>> env = FrameStackObservationV0(env, 4)
>>> env.observation_space >>> env.observation_space
Box(4, 96, 96, 3) Box(0, 255, (4, 96, 96, 3), uint8)
>>> obs = env.reset() >>> obs, _ = env.reset()
>>> obs.shape >>> obs.shape
(4, 96, 96, 3) (4, 96, 96, 3)
""" """

View File

@@ -44,12 +44,12 @@ class Box(Space[NDArray[Any]]):
* Identical bound for each dimension:: * Identical bound for each dimension::
>>> Box(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32) >>> Box(low=-1.0, high=2.0, shape=(3, 4), dtype=np.float32)
Box(3, 4) Box(-1.0, 2.0, (3, 4), float32)
* Independent bound for each dimension:: * Independent bound for each dimension::
>>> Box(low=np.array([-1.0, -2.0]), high=np.array([2.0, 4.0]), dtype=np.float32) >>> Box(low=np.array([-1.0, -2.0]), high=np.array([2.0, 4.0]), dtype=np.float32)
Box(2,) Box([-1. -2.], [2. 4.], (2,), float32)
""" """
def __init__( def __init__(

View File

@@ -19,14 +19,14 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
Example usage: Example usage:
>>> from gymnasium.spaces import Dict, Discrete >>> from gymnasium.spaces import Dict, Discrete
>>> observation_space = Dict({"position": Discrete(2), "velocity": Discrete(3)}) >>> observation_space = Dict({"position": Discrete(2), "velocity": Discrete(3)}, seed=42)
>>> observation_space.sample() >>> observation_space.sample()
OrderedDict([('position', 1), ('velocity', 2)]) OrderedDict([('position', 0), ('velocity', 2)])
Example usage [nested]:: Example usage [nested]::
>>> from gymnasium.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete >>> from gymnasium.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete
>>> Dict( >>> Dict( # doctest: +SKIP
... { ... {
... "ext_controller": MultiDiscrete([5, 2, 2]), ... "ext_controller": MultiDiscrete([5, 2, 2]),
... "inner_state": Dict( ... "inner_state": Dict(
@@ -66,9 +66,9 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]):
>>> from gymnasium.spaces import Box, Discrete >>> from gymnasium.spaces import Box, Discrete
>>> Dict({"position": Box(-1, 1, shape=(2,)), "color": Discrete(3)}) >>> Dict({"position": Box(-1, 1, shape=(2,)), "color": Discrete(3)})
Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32)) Dict('color': Discrete(3), 'position': Box(-1.0, 1.0, (2,), float32))
>>> Dict(position=Box(-1, 1, shape=(2,)), color=Discrete(3)) >>> Dict(position=Box(-1, 1, shape=(2,)), color=Discrete(3))
Dict(color:Discrete(3), position:Box(-1.0, 1.0, (2,), float32)) Dict('position': Box(-1.0, 1.0, (2,), float32), 'color': Discrete(3))
Args: Args:
spaces: A dictionary of spaces. This specifies the structure of the :class:`Dict` space spaces: A dictionary of spaces. This specifies the structure of the :class:`Dict` space

View File

@@ -16,7 +16,9 @@ class Discrete(Space[np.int64]):
Example:: Example::
>>> Discrete(2) # {0, 1} >>> Discrete(2) # {0, 1}
Discrete(2)
>>> Discrete(3, start=-1) # {-1, 0, 1} >>> Discrete(3, start=-1) # {-1, 0, 1}
Discrete(3, start=-1)
""" """
def __init__( def __init__(

View File

@@ -16,13 +16,13 @@ class MultiBinary(Space[npt.NDArray[np.int8]]):
Example Usage:: Example Usage::
>>> observation_space = MultiBinary(5) >>> observation_space = MultiBinary(5, seed=42)
>>> observation_space.sample() >>> observation_space.sample()
array([0, 1, 0, 1, 0], dtype=int8) array([1, 0, 1, 0, 1], dtype=int8)
>>> observation_space = MultiBinary([3, 2]) >>> observation_space = MultiBinary([3, 2], seed=42)
>>> observation_space.sample() >>> observation_space.sample()
array([[0, 0], array([[1, 0],
[0, 1], [1, 0],
[1, 1]], dtype=int8) [1, 1]], dtype=int8)
""" """

View File

@@ -32,7 +32,7 @@ class MultiDiscrete(Space[npt.NDArray[np.integer]]):
Example:: Example::
>> d = MultiDiscrete(np.array([[1, 2], [3, 4]])) >> d = MultiDiscrete(np.array([[1, 2], [3, 4]]), seed=42)
>> d.sample() >> d.sample()
array([[0, 0], array([[0, 0],
[2, 3]]) [2, 3]])

View File

@@ -19,11 +19,11 @@ class Sequence(Space[typing.Tuple[Any, ...]]):
Example:: Example::
>>> from gymnasium.spaces import Box >>> from gymnasium.spaces import Box
>>> space = Sequence(Box(0, 1)) >>> space = Sequence(Box(0, 1), seed=42)
>>> space.sample() >>> space.sample() # doctest: +SKIP
(array([0.0259352], dtype=float32),) (array([0.6369617], dtype=float32),)
>>> space.sample() >>> space.sample() # doctest: +SKIP
(array([0.80977976], dtype=float32), array([0.80066574], dtype=float32), array([0.77165383], dtype=float32)) (array([0.01652764], dtype=float32), array([0.8132702], dtype=float32),)
""" """
def __init__( def __init__(

View File

@@ -20,11 +20,13 @@ class Text(Space[str]):
Example:: Example::
>>> # {"", "B5", "hello", ...} >>> # {"", "B5", "hello", ...}
>>> Text(5) >>> Text(5)
Text(1, 5, characters=0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz)
>>> # {"0", "42", "0123456789", ...} >>> # {"0", "42", "0123456789", ...}
>>> import string >>> import string
>>> Text(min_length = 1, >>> Text(min_length = 1,
... max_length = 10, ... max_length = 10,
... charset = string.digits) ... charset = string.digits)
Text(1, 10, characters=0123456789)
""" """
def __init__( def __init__(

View File

@@ -18,9 +18,9 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]):
Example usage:: Example usage::
>>> from gymnasium.spaces import Box, Discrete >>> from gymnasium.spaces import Box, Discrete
>>> observation_space = Tuple((Discrete(2), Box(-1, 1, shape=(2,)))) >>> observation_space = Tuple((Discrete(2), Box(-1, 1, shape=(2,))), seed=42)
>>> observation_space.sample() >>> observation_space.sample()
(0, array([0.03633198, 0.42370757], dtype=float32)) (0, array([-0.3991573 , 0.21649833], dtype=float32))
""" """
def __init__( def __init__(

View File

@@ -392,9 +392,9 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
>>> from gymnasium.spaces import Box >>> from gymnasium.spaces import Box
>>> box = Box(0.0, 1.0, shape=(3, 4, 5)) >>> box = Box(0.0, 1.0, shape=(3, 4, 5))
>>> box >>> box
Box(3, 4, 5) Box(0.0, 1.0, (3, 4, 5), float32)
>>> flatten_space(box) >>> flatten_space(box)
Box(60,) Box(0.0, 1.0, (60,), float32)
>>> flatten(box, box.sample()) in flatten_space(box) >>> flatten(box, box.sample()) in flatten_space(box)
True True
@@ -402,7 +402,7 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
>>> from gymnasium.spaces import Discrete >>> from gymnasium.spaces import Discrete
>>> discrete = Discrete(5) >>> discrete = Discrete(5)
>>> flatten_space(discrete) >>> flatten_space(discrete)
Box(5,) Box(0, 1, (5,), int64)
>>> flatten(box, box.sample()) in flatten_space(box) >>> flatten(box, box.sample()) in flatten_space(box)
True True
@@ -410,7 +410,7 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph:
>>> from gymnasium.spaces import Dict, Discrete, Box >>> from gymnasium.spaces import Dict, Discrete, Box
>>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))}) >>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))})
>>> flatten_space(space) >>> flatten_space(space)
Box(6,) Box(0.0, 1.0, (6,), float64)
>>> flatten(space, space.sample()) in flatten_space(space) >>> flatten(space, space.sample()) in flatten_space(space)
True True

View File

@@ -5,8 +5,7 @@ class EzPickle:
"""Objects that are pickled and unpickled via their constructor arguments. """Objects that are pickled and unpickled via their constructor arguments.
Example:: Example::
>>> class Dog(Animal, EzPickle): # doctest: +SKIP
>>> class Dog(Animal, EzPickle):
... def __init__(self, furcolor, tailkind="bushy"): ... def __init__(self, furcolor, tailkind="bushy"):
... Animal.__init__() ... Animal.__init__()
... EzPickle.__init__(self, furcolor, tailkind) ... EzPickle.__init__(self, furcolor, tailkind)

View File

@@ -161,7 +161,7 @@ def play(
>>> import gymnasium as gym >>> import gymnasium as gym
>>> from gymnasium.utils.play import play >>> from gymnasium.utils.play import play
>>> play(gym.make("CarRacing-v1", render_mode="rgb_array"), keys_to_action={ >>> play(gym.make("CarRacing-v2", render_mode="rgb_array"), keys_to_action={ # doctest: +SKIP
... "w": np.array([0, 0.7, 0]), ... "w": np.array([0, 0.7, 0]),
... "a": np.array([-1, 0, 0]), ... "a": np.array([-1, 0, 0]),
... "s": np.array([0, 0, 1]), ... "s": np.array([0, 0, 1]),
@@ -181,10 +181,11 @@ def play(
for last 150 steps. for last 150 steps.
>>> import gymnasium as gym >>> import gymnasium as gym
>>> from gymnasium.utils.play import PlayPlot, play
>>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): >>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
... return [rew,] ... return [rew,]
>>> plotter = PlayPlot(callback, 150, ["reward"]) >>> plotter = PlayPlot(callback, 150, ["reward"]) # doctest: +SKIP
>>> play(gym.make("CartPole-v1"), callback=plotter.callback) >>> play(gym.make("CartPole-v1"), callback=plotter.callback) # doctest: +SKIP
Args: Args:
env: Environment to use for playing. env: Environment to use for playing.
@@ -207,7 +208,7 @@ def play(
For example if pressing 'w' and space at the same time is supposed For example if pressing 'w' and space at the same time is supposed
to trigger action number 2 then ``key_to_action`` dict could look like this: to trigger action number 2 then ``key_to_action`` dict could look like this:
>>> { >>> key_to_action = {
... # ... ... # ...
... (ord('w'), ord(' ')): 2 ... (ord('w'), ord(' ')): 2
... # ... ... # ...
@@ -215,7 +216,7 @@ def play(
or like this: or like this:
>>> { >>> key_to_action = {
... # ... ... # ...
... ("w", " "): 2 ... ("w", " "): 2
... # ... ... # ...
@@ -223,7 +224,7 @@ def play(
or like this: or like this:
>>> { >>> key_to_action = {
... # ... ... # ...
... "w ": 2 ... "w ": 2
... # ... ... # ...
@@ -315,9 +316,9 @@ class PlayPlot:
Typically, this :meth:`callback` will be used in conjunction with :func:`play` to see how the metrics evolve as you play:: Typically, this :meth:`callback` will be used in conjunction with :func:`play` to see how the metrics evolve as you play::
>>> plotter = PlayPlot(compute_metrics, horizon_timesteps=200, >>> plotter = PlayPlot(compute_metrics, horizon_timesteps=200, # doctest: +SKIP
... plot_names=["Immediate Rew.", "Cumulative Rew.", "Action Magnitude"]) ... plot_names=["Immediate Rew.", "Cumulative Rew.", "Action Magnitude"])
>>> play(your_env, callback=plotter.callback) >>> play(your_env, callback=plotter.callback) # doctest: +SKIP
""" """
def __init__( def __init__(

View File

@@ -63,10 +63,10 @@ def save_video(
>>> import gymnasium as gym >>> import gymnasium as gym
>>> from gymnasium.utils.save_video import save_video >>> from gymnasium.utils.save_video import save_video
>>> env = gym.make("FrozenLake-v1", render_mode="rgb_array_list") >>> env = gym.make("FrozenLake-v1", render_mode="rgb_array_list")
>>> env.reset() >>> _ = env.reset()
>>> step_starting_index = 0 >>> step_starting_index = 0
>>> episode_index = 0 >>> episode_index = 0
>>> for step_index in range(199): >>> for step_index in range(199): # doctest: +SKIP
... action = env.action_space.sample() ... action = env.action_space.sample()
... _, _, terminated, truncated, _ = env.step(action) ... _, _, terminated, truncated, _ = env.step(action)
... ...

View File

@@ -153,12 +153,17 @@ def step_api_compatibility(
wrapper is written in new API, and the final step output is desired to be in old API. wrapper is written in new API, and the final step output is desired to be in old API.
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.make("OldEnv") >>> env = gym.make("CartPole-v0")
>>> obs, rew, done, info = step_api_compatibility(env.step(action), output_truncation_bool=False) >>> _ = env.reset()
>>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), output_truncation_bool=True) >>> obs, rewards, done, info = step_api_compatibility(env.step(0), output_truncation_bool=False)
>>> obs, rewards, terminated, truncated, info = step_api_compatibility(env.step(0), output_truncation_bool=True)
>>> vec_env = gym.vector.make("CartPole-v0")
>>> _ = vec_env.reset()
>>> obs, rewards, dones, infos = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=False)
>>> obs, rewards, terminated, truncated, info = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=True)
>>> vec_env = gym.vector.make("OldEnv")
>>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True)
""" """
if output_truncation_bool: if output_truncation_bool:
return convert_to_terminated_truncated_step_api(step_returns, is_vector_env) return convert_to_terminated_truncated_step_api(step_returns, is_vector_env)

View File

@@ -25,11 +25,11 @@ def make(
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.vector.make('CartPole-v1', num_envs=3) >>> env = gym.vector.make('CartPole-v1', num_envs=3)
>>> env.reset() >>> env.reset(seed=42)
array([[-0.04456399, 0.04653909, 0.01326909, -0.02099827], (array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ],
[ 0.03073904, 0.00145001, -0.03088818, -0.03131252], [ 0.01522993, -0.04562247, -0.04799704, 0.03392126],
[ 0.03468829, 0.01500225, 0.01230312, 0.01825218]], [-0.03774345, -0.02418869, -0.00942293, 0.0469184 ]],
dtype=float32) dtype=float32), {})
Args: Args:
id: The environment ID. This must be a valid ID from the registry. id: The environment ID. This must be a valid ID from the registry.

View File

@@ -49,12 +49,12 @@ class AsyncVectorEnv(VectorEnv):
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.vector.AsyncVectorEnv([ >>> env = gym.vector.AsyncVectorEnv([
... lambda: gym.make("Pendulum-v0", g=9.81), ... lambda: gym.make("Pendulum-v1", g=9.81),
... lambda: gym.make("Pendulum-v0", g=1.62) ... lambda: gym.make("Pendulum-v1", g=1.62)
... ]) ... ])
>>> env.reset() >>> env.reset(seed=42)
array([[-0.8286432 , 0.5597771 , 0.90249056], (array([[-0.14995256, 0.9886932 , -0.12224312],
[-0.85009176, 0.5266346 , 0.60007906]], dtype=float32) [ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {})
""" """
def __init__( def __init__(

View File

@@ -20,12 +20,12 @@ class SyncVectorEnv(VectorEnv):
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.vector.SyncVectorEnv([ >>> env = gym.vector.SyncVectorEnv([
... lambda: gym.make("Pendulum-v0", g=9.81), ... lambda: gym.make("Pendulum-v1", g=9.81),
... lambda: gym.make("Pendulum-v0", g=1.62) ... lambda: gym.make("Pendulum-v1", g=1.62)
... ]) ... ])
>>> env.reset() >>> env.reset(seed=42)
array([[-0.8286432 , 0.5597771 , 0.90249056], (array([[-0.14995256, 0.9886932 , -0.12224312],
[-0.85009176, 0.5266346 , 0.60007906]], dtype=float32) [ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {})
""" """
def __init__( def __init__(

View File

@@ -28,12 +28,14 @@ def concatenate(
Example:: Example::
>>> from gymnasium.spaces import Box >>> from gymnasium.spaces import Box
>>> space = Box(low=0, high=1, shape=(3,), dtype=np.float32) >>> import numpy as np
>>> space = Box(low=0, high=1, shape=(3,), seed=42, dtype=np.float32)
>>> out = np.zeros((2, 3), dtype=np.float32) >>> out = np.zeros((2, 3), dtype=np.float32)
>>> items = [space.sample() for _ in range(2)] >>> items = [space.sample() for _ in range(2)]
>>> concatenate(space, items, out) >>> concatenate(space, items, out)
array([[0.6348213 , 0.28607962, 0.60760117], array([[0.77395606, 0.43887845, 0.85859793],
[0.87383074, 0.192658 , 0.2148103 ]], dtype=float32) [0.697368 , 0.09417735, 0.97562236]], dtype=float32)
Args: Args:
space: Observation space of a single environment in the vectorized environment. space: Observation space of a single environment in the vectorized environment.
@@ -91,15 +93,17 @@ def create_empty_array(
Example:: Example::
>>> from gymnasium.spaces import Box, Dict >>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({ >>> space = Dict({
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32), ... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)}) ... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
>>> create_empty_array(space, n=2, fn=np.zeros) >>> create_empty_array(space, n=2, fn=np.zeros)
OrderedDict([('position', array([[0., 0., 0.], OrderedDict([('position', array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)), [0., 0., 0.]], dtype=float32)), ('velocity', array([[0., 0.],
('velocity', array([[0., 0.],
[0., 0.]], dtype=float32))]) [0., 0.]], dtype=float32))])
Args: Args:
space: Observation space of a single environment in the vectorized environment. space: Observation space of a single environment in the vectorized environment.
n: Number of environments in the vectorized environment. If `None`, creates an empty sample from `space`. n: Number of environments in the vectorized environment. If `None`, creates an empty sample from `space`.

View File

@@ -30,12 +30,13 @@ def batch_space(space: Space, n: int = 1) -> Space:
Example:: Example::
>>> from gymnasium.spaces import Box, Dict >>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({ >>> space = Dict({
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32), ... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32) ... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)
... }) ... })
>>> batch_space(space, n=5) >>> batch_space(space, n=5)
Dict(position:Box(5, 3), velocity:Box(5, 2)) Dict('position': Box(0.0, 1.0, (5, 3), float32), 'velocity': Box(0.0, 1.0, (5, 2), float32))
Args: Args:
space: Space (e.g. the observation space) for a single environment in the vectorized environment. space: Space (e.g. the observation space) for a single environment in the vectorized environment.
@@ -140,18 +141,17 @@ def iterate(space: Space, items) -> Iterator:
Example:: Example::
>>> from gymnasium.spaces import Box, Dict >>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({ >>> space = Dict({
... 'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32), ... 'position': Box(low=0, high=1, shape=(2, 3), seed=42, dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2, 2), dtype=np.float32)}) ... 'velocity': Box(low=0, high=1, shape=(2, 2), seed=42, dtype=np.float32)})
>>> items = space.sample() >>> items = space.sample()
>>> it = iterate(space, items) >>> it = iterate(space, items)
>>> next(it) >>> next(it)
{'position': array([-0.99644893, -0.08304597, -0.7238421 ], dtype=float32), OrderedDict([('position', array([0.77395606, 0.43887845, 0.85859793], dtype=float32)), ('velocity', array([0.77395606, 0.43887845], dtype=float32))])
'velocity': array([0.35848552, 0.1533453 ], dtype=float32)}
>>> next(it)
{'position': array([-0.67958736, -0.49076623, 0.38661423], dtype=float32),
'velocity': array([0.7975036 , 0.93317133], dtype=float32)}
>>> next(it) >>> next(it)
OrderedDict([('position', array([0.697368 , 0.09417735, 0.97562236], dtype=float32)), ('velocity', array([0.85859793, 0.697368 ], dtype=float32))])
>>> next(it) # doctest: +SKIP
StopIteration StopIteration
Args: Args:

View File

@@ -129,11 +129,12 @@ class VectorEnv(gym.Env):
>>> import gymnasium as gym >>> import gymnasium as gym
>>> envs = gym.vector.make("CartPole-v1", num_envs=3) >>> envs = gym.vector.make("CartPole-v1", num_envs=3)
>>> envs.reset() >>> envs.reset(seed=42)
(array([[-0.02240574, -0.03439831, -0.03904812, 0.02810693], (array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ],
[ 0.01586068, 0.01929009, 0.02394426, 0.04016077], [ 0.01522993, -0.04562247, -0.04799704, 0.03392126],
[-0.01314174, 0.03893502, -0.02400815, 0.0038326 ]], [-0.03774345, -0.02418869, -0.00942293, 0.0469184 ]],
dtype=float32), {}) dtype=float32), {})
""" """
self.reset_async(seed=seed, options=options) self.reset_async(seed=seed, options=options)
return self.reset_wait(seed=seed, options=options) return self.reset_wait(seed=seed, options=options)
@@ -176,14 +177,13 @@ class VectorEnv(gym.Env):
An example:: An example::
>>> envs = gym.vector.make("CartPole-v1", num_envs=3) >>> envs = gym.vector.make("CartPole-v1", num_envs=3)
>>> envs.reset() >>> _ = envs.reset(seed=42)
>>> actions = np.array([1, 0, 1]) >>> actions = np.array([1, 0, 1])
>>> observations, rewards, termination, truncation, infos = envs.step(actions) >>> observations, rewards, termination, truncation, infos = envs.step(actions)
>>> observations >>> observations
array([[ 0.00122802, 0.16228443, 0.02521779, -0.23700266], array([[ 0.02727336, 0.18847767, 0.03625453, -0.26141977],
[ 0.00788269, -0.17490888, 0.03393489, 0.31735462], [ 0.01431748, -0.24002443, -0.04731862, 0.3110827 ],
[ 0.04918966, 0.19421194, 0.02938497, -0.29495203]], [-0.03822722, 0.1710671 , -0.00848456, -0.2487226 ]],
dtype=float32) dtype=float32)
>>> rewards >>> rewards
array([1., 1., 1.]) array([1., 1., 1.])

View File

@@ -10,29 +10,29 @@ with (possibly optional) parameters to the wrapper's constructor.
>>> import gymnasium as gym >>> import gymnasium as gym
>>> from gymnasium.wrappers import RescaleAction >>> from gymnasium.wrappers import RescaleAction
>>> base_env = gym.make("BipedalWalker-v3") >>> base_env = gym.make("Hopper-v4")
>>> base_env.action_space >>> base_env.action_space
Box([-1. -1. -1. -1.], [1. 1. 1. 1.], (4,), float32) Box(-1.0, 1.0, (3,), float32)
>>> wrapped_env = RescaleAction(base_env, min_action=0, max_action=1) >>> wrapped_env = RescaleAction(base_env, min_action=0, max_action=1)
>>> wrapped_env.action_space >>> wrapped_env.action_space
Box([0. 0. 0. 0.], [1. 1. 1. 1.], (4,), float32) Box(-1.0, 1.0, (3,), float32)
You can access the environment underneath the **first** wrapper by using the :attr:`gymnasium.Wrapper.env` attribute. You can access the environment underneath the **first** wrapper by using the :attr:`gymnasium.Wrapper.env` attribute.
As the :class:`gymnasium.Wrapper` class inherits from :class:`gymnasium.Env`, :attr:`gymnasium.Wrapper.env` can itself be another wrapper. As the :class:`gymnasium.Wrapper` class inherits from :class:`gymnasium.Env`, :attr:`gymnasium.Wrapper.env` can itself be another wrapper.
>>> wrapped_env >>> wrapped_env
<RescaleAction<TimeLimit<OrderEnforcing<BipedalWalker<BipedalWalker-v3>>>>> <RescaleAction<TimeLimit<OrderEnforcing<PassiveEnvChecker<HopperEnv<Hopper-v4>>>>>>
>>> wrapped_env.env >>> wrapped_env.env
<TimeLimit<OrderEnforcing<BipedalWalker<BipedalWalker-v3>>>> <TimeLimit<OrderEnforcing<PassiveEnvChecker<HopperEnv<Hopper-v4>>>>>
If you want to get to the environment underneath **all** of the layers of wrappers, you can use the If you want to get to the environment underneath **all** of the layers of wrappers, you can use the
:attr:`gymnasium.Wrapper.unwrapped` attribute. :attr:`gymnasium.Wrapper.unwrapped` attribute.
If the environment is already a bare environment, the :attr:`gymnasium.Wrapper.unwrapped` attribute will just return itself. If the environment is already a bare environment, the :attr:`gymnasium.Wrapper.unwrapped` attribute will just return itself.
>>> wrapped_env >>> wrapped_env
<RescaleAction<TimeLimit<OrderEnforcing<BipedalWalker<BipedalWalker-v3>>>>> <RescaleAction<TimeLimit<OrderEnforcing<PassiveEnvChecker<HopperEnv<Hopper-v4>>>>>>
>>> wrapped_env.unwrapped >>> wrapped_env.unwrapped # doctest: +SKIP
<gymnasium.envs.box2d.bipedal_walker.BipedalWalker object at 0x7f87d70712d0> <gymnasium.envs.mujoco.hopper_v4.HopperEnv object at 0x7fbb5efd0490>
There are three common things you might want a wrapper to do: There are three common things you might want a wrapper to do:

View File

@@ -11,12 +11,14 @@ class ClipAction(ActionWrapper):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.make('Bipedal-Walker-v3') >>> from gymnasium.wrappers import ClipAction
>>> env = gym.make("Hopper-v4")
>>> env = ClipAction(env) >>> env = ClipAction(env)
>>> env.action_space >>> env.action_space
Box(-1.0, 1.0, (4,), float32) Box(-1.0, 1.0, (3,), float32)
>>> env.step(np.array([5.0, 2.0, -10.0, 0.0])) >>> _ = env.reset(seed=42)
# Executes the action np.array([1.0, 1.0, -1.0, 0]) in the base environment >>> _ = env.step(np.array([5.0, -2.0, 0.0]))
... # Executes the action np.array([1.0, -1.0, 0]) in the base environment
""" """
def __init__(self, env: gym.Env): def __init__(self, env: gym.Env):

View File

@@ -11,17 +11,17 @@ class FilterObservation(gym.ObservationWrapper):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.wrappers.TransformObservation( >>> from gymnasium.wrappers import TransformObservation
... gym.make('CartPole-v1'), lambda obs: {'obs': obs, 'time': 0} >>> env = gym.make("CartPole-v1")
... ) >>> env = TransformObservation(env, lambda obs: {'obs': obs, 'time': 0})
>>> env.observation_space = gym.spaces.Dict(obs=env.observation_space, time=gym.spaces.Discrete(1)) >>> env.observation_space = gym.spaces.Dict(obs=env.observation_space, time=gym.spaces.Discrete(1))
>>> env.reset() >>> env.reset(seed=42)
{'obs': array([-0.00067088, -0.01860439, 0.04772898, -0.01911527], dtype=float32), 'time': 0} ({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': 0}, {})
>>> env = FilterObservation(env, filter_keys=['obs']) >>> env = FilterObservation(env, filter_keys=['obs'])
>>> env.reset() >>> env.reset(seed=42)
{'obs': array([ 0.04560107, 0.04466959, -0.0328232 , -0.02367178], dtype=float32)} ({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32)}, {})
>>> env.step(0) >>> env.step(0)
({'obs': array([ 0.04649447, -0.14996664, -0.03329664, 0.25847703], dtype=float32)}, 1.0, False, {}) ({'obs': array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476], dtype=float32)}, 1.0, False, False, {})
""" """
def __init__(self, env: gym.Env, filter_keys: Sequence[str] = None): def __init__(self, env: gym.Env, filter_keys: Sequence[str] = None):

View File

@@ -8,13 +8,14 @@ class FlattenObservation(gym.ObservationWrapper):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.make('CarRacing-v1') >>> from gymnasium.wrappers import FlattenObservation
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape >>> env.observation_space.shape
(96, 96, 3) (96, 96, 3)
>>> env = FlattenObservation(env) >>> env = FlattenObservation(env)
>>> env.observation_space.shape >>> env.observation_space.shape
(27648,) (27648,)
>>> obs, info = env.reset() >>> obs, _ = env.reset()
>>> obs.shape >>> obs.shape
(27648,) (27648,)
""" """

View File

@@ -114,11 +114,12 @@ class FrameStack(gym.ObservationWrapper):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.make('CarRacing-v1') >>> from gymnasium.wrappers import FrameStack
>>> env = gym.make("CarRacing-v2")
>>> env = FrameStack(env, 4) >>> env = FrameStack(env, 4)
>>> env.observation_space >>> env.observation_space
Box(4, 96, 96, 3) Box(0, 255, (4, 96, 96, 3), uint8)
>>> obs = env.reset() >>> obs, _ = env.reset()
>>> obs.shape >>> obs.shape
(4, 96, 96, 3) (4, 96, 96, 3)
""" """

View File

@@ -9,13 +9,15 @@ class GrayScaleObservation(gym.ObservationWrapper):
"""Convert the image observation from RGB to gray scale. """Convert the image observation from RGB to gray scale.
Example: Example:
>>> env = gym.make('CarRacing-v1') >>> import gymnasium as gym
>>> from gymnasium.wrappers import GrayScaleObservation
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space >>> env.observation_space
Box(0, 255, (96, 96, 3), uint8) Box(0, 255, (96, 96, 3), uint8)
>>> env = GrayScaleObservation(gym.make('CarRacing-v1')) >>> env = GrayScaleObservation(gym.make("CarRacing-v2"))
>>> env.observation_space >>> env.observation_space
Box(0, 255, (96, 96), uint8) Box(0, 255, (96, 96), uint8)
>>> env = GrayScaleObservation(gym.make('CarRacing-v1'), keep_dim=True) >>> env = GrayScaleObservation(gym.make("CarRacing-v2"), keep_dim=True)
>>> env.observation_space >>> env.observation_space
Box(0, 255, (96, 96, 1), uint8) Box(0, 255, (96, 96, 1), uint8)
""" """

View File

@@ -18,26 +18,28 @@ class HumanRendering(gym.Wrapper):
The ``render_mode`` of the wrapped environment must be either ``'rgb_array'`` or ``'rgb_array_list'``. The ``render_mode`` of the wrapped environment must be either ``'rgb_array'`` or ``'rgb_array_list'``.
Example: Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import HumanRendering
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array") >>> env = gym.make("LunarLander-v2", render_mode="rgb_array")
>>> wrapped = HumanRendering(env) >>> wrapped = HumanRendering(env)
>>> wrapped.reset() # This will start rendering to the screen >>> obs, _ = wrapped.reset() # This will start rendering to the screen
The wrapper can also be applied directly when the environment is instantiated, simply by passing The wrapper can also be applied directly when the environment is instantiated, simply by passing
``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not ``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not
implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``). implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``).
Example: Example:
>>> env = gym.make("NoNativeRendering-v2", render_mode="human") # NoNativeRendering-v0 doesn't implement human-rendering natively >>> env = gym.make("CartPoleJax-v1", render_mode="human") # CartPoleJax-v1 doesn't implement human-rendering natively
>>> env.reset() # This will start rendering to the screen >>> obs, _ = env.reset() # This will start rendering to the screen
Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. the *base environment's*) render method
will always return an empty list: will always return an empty list:
>>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list") >>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list")
>>> wrapped = HumanRendering(env) >>> wrapped = HumanRendering(env)
>>> wrapped.reset() >>> obs, _ = wrapped.reset()
>>> env.render() >>> env.render() # env.render() will always return an empty list!
[] # env.render() will always return an empty list! []
""" """

View File

@@ -7,16 +7,17 @@ class OrderEnforcing(gym.Wrapper):
"""A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`. """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
Example: Example:
>>> from gymnasium.envs.classic_control import CartPoleEnv >>> import gymnasium as gym
>>> env = CartPoleEnv() >>> from gymnasium.wrappers import OrderEnforcing
>>> env = gym.make("CartPole-v1", render_mode="human")
>>> env = OrderEnforcing(env) >>> env = OrderEnforcing(env)
>>> env.step(0) >>> env.step(0) # doctest: +SKIP
ResetNeeded: Cannot call env.step() before calling env.reset() gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset()
>>> env.render() # doctest: +SKIP
gymnasium.error.ResetNeeded('Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.')
>>> _ = env.reset()
>>> env.render() >>> env.render()
ResetNeeded: Cannot call env.render() before calling env.reset() >>> _ = env.step(0)
>>> env.reset()
>>> env.render()
>>> env.step(0)
""" """
def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False): def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):

View File

@@ -25,22 +25,23 @@ class PixelObservationWrapper(gym.ObservationWrapper):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = PixelObservationWrapper(gym.make('CarRacing-v1', render_mode="rgb_array")) >>> from gymnasium.wrappers import PixelObservationWrapper
>>> obs = env.reset() >>> env = PixelObservationWrapper(gym.make("CarRacing-v2", render_mode="rgb_array"))
>>> obs, _ = env.reset()
>>> obs.keys() >>> obs.keys()
odict_keys(['pixels']) odict_keys(['pixels'])
>>> obs['pixels'].shape >>> obs['pixels'].shape
(400, 600, 3) (400, 600, 3)
>>> env = PixelObservationWrapper(gym.make('CarRacing-v1', render_mode="rgb_array"), pixels_only=False) >>> env = PixelObservationWrapper(gym.make("CarRacing-v2", render_mode="rgb_array"), pixels_only=False)
>>> obs = env.reset() >>> obs, _ = env.reset()
>>> obs.keys() >>> obs.keys()
odict_keys(['state', 'pixels']) odict_keys(['state', 'pixels'])
>>> obs['state'].shape >>> obs['state'].shape
(96, 96, 3) (96, 96, 3)
>>> obs['pixels'].shape >>> obs['pixels'].shape
(400, 600, 3) (400, 600, 3)
>>> env = PixelObservationWrapper(gym.make('CarRacing-v1', render_mode="rgb_array"), pixel_keys=('obs',)) >>> env = PixelObservationWrapper(gym.make("CarRacing-v2", render_mode="rgb_array"), pixel_keys=('obs',))
>>> obs = env.reset() >>> obs, _ = env.reset()
>>> obs.keys() >>> obs.keys()
odict_keys(['obs']) odict_keys(['obs'])
>>> obs['obs'].shape >>> obs['obs'].shape

View File

@@ -19,7 +19,6 @@ class RecordEpisodeStatistics(gym.Wrapper):
After the completion of an episode, ``info`` will look like this:: After the completion of an episode, ``info`` will look like this::
>>> info = { >>> info = {
... ...
... "episode": { ... "episode": {
... "r": "<cumulative reward>", ... "r": "<cumulative reward>",
... "l": "<episode length>", ... "l": "<episode length>",
@@ -30,7 +29,10 @@ class RecordEpisodeStatistics(gym.Wrapper):
For vectorized environments, the output will be in the form of:: For vectorized environments, the output will be in the form of::
>>> infos = { >>> infos = {
... ... ... "final_observation": "<array of length num-envs>",
... "_final_observation": "<boolean array of length num-envs>",
... "final_info": "<array of length num-envs>",
... "_final_info": "<boolean array of length num-envs>",
... "episode": { ... "episode": {
... "r": "<array of cumulative reward>", ... "r": "<array of cumulative reward>",
... "l": "<array of episode length>", ... "l": "<array of episode length>",

View File

@@ -15,15 +15,17 @@ class RescaleAction(gym.ActionWrapper):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.make('BipedalWalker-v3') >>> from gymnasium.wrappers import RescaleAction
>>> env.action_space >>> import numpy as np
Box(-1.0, 1.0, (4,), float32) >>> env = gym.make("Hopper-v4")
>>> _ = env.reset(seed=42)
>>> obs, _, _, _, _ = env.step(np.array([1,1,1]))
>>> _ = env.reset(seed=42)
>>> min_action = -0.5 >>> min_action = -0.5
>>> max_action = np.array([0.0, 0.5, 1.0, 0.75]) >>> max_action = np.array([0.0, 0.5, 0.75])
>>> env = RescaleAction(env, min_action=min_action, max_action=max_action) >>> wrapped_env = RescaleAction(env, min_action=min_action, max_action=max_action)
>>> env.action_space >>> wrapped_env_obs, _, _, _, _ = wrapped_env.step(max_action)
Box(-0.5, [0. 0.5 1. 0.75], (4,), float32) >>> np.alltrue(obs == wrapped_env_obs)
>>> RescaleAction(env, min_action, max_action).action_space == gym.spaces.Box(min_action, max_action)
True True
""" """

View File

@@ -20,7 +20,8 @@ class ResizeObservation(gym.ObservationWrapper):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.make('CarRacing-v1') >>> from gymnasium.wrappers import ResizeObservation
>>> env = gym.make("CarRacing-v2")
>>> env.observation_space.shape >>> env.observation_space.shape
(96, 96, 3) (96, 96, 3)
>>> env = ResizeObservation(env, 64) >>> env = ResizeObservation(env, 64)

View File

@@ -16,12 +16,14 @@ class StepAPICompatibility(gym.Wrapper):
output_truncation_bool (bool): If True, convert the environment to use the new step API that returns two bools (``terminated`` and ``truncated``). (True by default) output_truncation_bool (bool): If True, convert the environment to use the new step API that returns two bools (``terminated`` and ``truncated``). (True by default)
Examples: Examples:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import StepAPICompatibility
>>> env = gym.make("CartPole-v1") >>> env = gym.make("CartPole-v1")
>>> env # wrapper not applied by default, set to new API >>> env # wrapper not applied by default, set to new API
<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>> <TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
>>> env = gym.make("CartPole-v1", apply_api_compatibility=True) # set to old API >>> env = StepAPICompatibility(gym.make("CartPole-v1"))
>>> env
<StepAPICompatibility<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>> <StepAPICompatibility<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>>
>>> env = StepAPICompatibility(CustomEnv(), output_truncation_bool=False) # manually using wrapper on unregistered envs
""" """

View File

@@ -13,12 +13,14 @@ class TimeAwareObservation(gym.ObservationWrapper):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.make('CartPole-v1') >>> from gymnasium.wrappers import TimeAwareObservation
>>> env = gym.make("CartPole-v1")
>>> env = TimeAwareObservation(env) >>> env = TimeAwareObservation(env)
>>> env.reset() >>> env.reset(seed=42)
array([ 0.03810719, 0.03522411, 0.02231044, -0.01088205, 0. ]) (array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 , 0. ]), {})
>>> _ = env.action_space.seed(42)
>>> env.step(env.action_space.sample())[0] >>> env.step(env.action_space.sample())[0]
array([ 0.03881167, -0.16021058, 0.0220928 , 0.28875574, 1. ]) array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476, 1. ])
""" """
def __init__(self, env: gym.Env): def __init__(self, env: gym.Env):

View File

@@ -11,9 +11,9 @@ class TimeLimit(gym.Wrapper):
Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP. Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP.
Example: Example:
>>> from gymnasium.envs.classic_control import CartPoleEnv >>> import gymnasium as gym
>>> from gymnasium.wrappers import TimeLimit >>> from gymnasium.wrappers import TimeLimit
>>> env = CartPoleEnv() >>> env = gym.make("CartPole-v1")
>>> env = TimeLimit(env, max_episode_steps=1000) >>> env = TimeLimit(env, max_episode_steps=1000)
""" """

View File

@@ -13,11 +13,13 @@ class TransformObservation(gym.ObservationWrapper):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> from gymnasium.wrappers import TransformObservation
>>> import numpy as np >>> import numpy as np
>>> env = gym.make('CartPole-v1') >>> np.random.seed(0)
>>> env = TransformObservation(env, lambda obs: obs + 0.1*np.random.randn(*obs.shape)) >>> env = gym.make("CartPole-v1")
>>> env.reset() >>> env = TransformObservation(env, lambda obs: obs + 0.1 * np.random.randn(*obs.shape))
array([-0.08319338, 0.04635121, -0.07394746, 0.20877492]) >>> env.reset(seed=42)
(array([0.20380084, 0.03390356, 0.13373359, 0.24382612]), {})
""" """
def __init__(self, env: gym.Env, f: Callable[[Any], Any]): def __init__(self, env: gym.Env, f: Callable[[Any], Any]):

View File

@@ -13,9 +13,10 @@ class TransformReward(RewardWrapper):
Example: Example:
>>> import gymnasium as gym >>> import gymnasium as gym
>>> env = gym.make('CartPole-v1') >>> from gymnasium.wrappers import TransformReward
>>> env = gym.make("CartPole-v1")
>>> env = TransformReward(env, lambda r: 0.01*r) >>> env = TransformReward(env, lambda r: 0.01*r)
>>> env.reset() >>> _ = env.reset()
>>> observation, reward, terminated, truncated, info = env.step(env.action_space.sample()) >>> observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
>>> reward >>> reward
0.01 0.01

View File

@@ -18,14 +18,28 @@ class VectorListInfo(gym.Wrapper):
i.e. `VectorListInfo(RecordEpisodeStatistics(envs))` i.e. `VectorListInfo(RecordEpisodeStatistics(envs))`
Example:: Example::
>>> # As dict:
>>> # actual >>> infos = {
>>> { ... "final_observation": "<array of length num-envs>",
... "k": np.array[0., 0., 0.5, 0.3], ... "_final_observation": "<boolean array of length num-envs>",
... "_k": np.array[False, False, True, True] ... "final_info": "<array of length num-envs>",
... "_final_info": "<boolean array of length num-envs>",
... "episode": {
... "r": "<array of cumulative reward>",
... "l": "<array of episode length>",
... "t": "<array of elapsed time since beginning of episode>"
... },
... "_episode": "<boolean array of length num-envs>"
... } ... }
>>> # classic >>> # As list:
>>> [{}, {}, {k: 0.5}, {k: 0.3}] >>> infos = [
... {
... "episode": {"r": "<cumulative reward>", "l": "<episode length>", "t": "<elapsed time since beginning of episode>"},
... "final_observation": "<observation>",
... "final_info": {},
... },
... ...,
... ]
""" """
def __init__(self, env): def __init__(self, env):