Update RescaleAction and RescaleObservation for np.inf bounds (#1095)

This commit is contained in:
Tim Schneider
2024-07-03 15:53:40 +02:00
committed by GitHub
parent b064b684bb
commit fc55d47039
12 changed files with 178 additions and 133 deletions

View File

@@ -27,7 +27,7 @@ title: Env
>>> env.action_space
Discrete(2)
>>> env.observation_space
Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)
Box(-inf, inf, (4,), float32)
.. autoattribute:: gymnasium.Env.observation_space
@@ -36,9 +36,9 @@ title: Env
.. code::
>>> env.observation_space.high
array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38], dtype=float32)
array([4.8000002e+00, inf, 4.1887903e-01, inf], dtype=float32)
>>> env.observation_space.low
array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38], dtype=float32)
array([-4.8000002e+00, -inf, -4.1887903e-01, -inf], dtype=float32)
.. autoattribute:: gymnasium.Env.metadata

View File

@@ -140,9 +140,9 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
high = np.array(
[
self.x_threshold * 2,
np.finfo(np.float32).max,
np.inf,
self.theta_threshold_radians * 2,
np.finfo(np.float32).max,
np.inf,
],
dtype=np.float32,
)
@@ -401,9 +401,9 @@ class CartPoleVectorEnv(VectorEnv):
high = np.array(
[
self.x_threshold * 2,
np.finfo(np.float32).max,
np.inf,
self.theta_threshold_radians * 2,
np.finfo(np.float32).max,
np.inf,
],
dtype=np.float32,
)

View File

@@ -59,16 +59,13 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
>>> envs.action_space
MultiDiscrete([2 2 2])
>>> envs.observation_space
Box([[-4.80000019e+00 -3.40282347e+38 -4.18879032e-01 -3.40282347e+38
0.00000000e+00]
[-4.80000019e+00 -3.40282347e+38 -4.18879032e-01 -3.40282347e+38
0.00000000e+00]
[-4.80000019e+00 -3.40282347e+38 -4.18879032e-01 -3.40282347e+38
0.00000000e+00]], [[4.80000019e+00 3.40282347e+38 4.18879032e-01 3.40282347e+38
Box([[-4.80000019 -inf -0.41887903 -inf 0. ]
[-4.80000019 -inf -0.41887903 -inf 0. ]
[-4.80000019 -inf -0.41887903 -inf 0. ]], [[4.80000019e+00 inf 4.18879032e-01 inf
5.00000000e+02]
[4.80000019e+00 3.40282347e+38 4.18879032e-01 3.40282347e+38
[4.80000019e+00 inf 4.18879032e-01 inf
5.00000000e+02]
[4.80000019e+00 3.40282347e+38 4.18879032e-01 3.40282347e+38
[4.80000019e+00 inf 4.18879032e-01 inf
5.00000000e+02]], (3, 5), float64)
>>> observations, infos = envs.reset(seed=123)
>>> observations

View File

@@ -129,8 +129,7 @@ class TimeAwareObservation(
>>> env = gym.make("CartPole-v1")
>>> env = TimeAwareObservation(env)
>>> env.observation_space
Box([-4.80000019e+00 -3.40282347e+38 -4.18879032e-01 -3.40282347e+38
0.00000000e+00], [4.80000019e+00 3.40282347e+38 4.18879032e-01 3.40282347e+38
Box([-4.80000019 -inf -0.41887903 -inf 0. ], [4.80000019e+00 inf 4.18879032e-01 inf
5.00000000e+02], (5,), float64)
>>> env.reset(seed=42)[0]
array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 , 0. ])
@@ -142,8 +141,7 @@ class TimeAwareObservation(
>>> env = gym.make('CartPole-v1')
>>> env = TimeAwareObservation(env, normalize_time=True)
>>> env.observation_space
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38
0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 1.0000000e+00], (5,), float32)
Box([-4.8 -inf -0.41887903 -inf 0. ], [4.8 inf 0.41887903 inf 1. ], (5,), float32)
>>> env.reset(seed=42)[0]
array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 , 0. ],
dtype=float32)
@@ -156,7 +154,7 @@ class TimeAwareObservation(
>>> env = gym.make("CartPole-v1")
>>> env = TimeAwareObservation(env, flatten=False)
>>> env.observation_space
Dict('obs': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), 'time': Box(0, 500, (1,), int32))
Dict('obs': Box([-4.8 -inf -0.41887903 -inf], [4.8 inf 0.41887903 inf], (4,), float32), 'time': Box(0, 500, (1,), int32))
>>> env.reset(seed=42)[0]
{'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': array([0], dtype=int32)}
>>> _ = env.action_space.seed(42)

View File

@@ -18,6 +18,8 @@ from gymnasium.spaces import Box, Space
__all__ = ["TransformAction", "ClipAction", "RescaleAction"]
from gymnasium.wrappers.utils import rescale_box
class TransformAction(
gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
@@ -153,8 +155,8 @@ class RescaleAction(
def __init__(
self,
env: gym.Env[ObsType, ActType],
min_action: float | int | np.ndarray,
max_action: float | int | np.ndarray,
min_action: np.floating | np.integer | np.ndarray,
max_action: np.floating | np.integer | np.ndarray,
):
"""Constructor for the Rescale Action wrapper.
@@ -163,49 +165,16 @@ class RescaleAction(
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
"""
assert isinstance(env.action_space, Box)
gym.utils.RecordConstructorArgs.__init__(
self, min_action=min_action, max_action=max_action
)
assert isinstance(env.action_space, Box)
assert not np.any(env.action_space.low == np.inf) and not np.any(
env.action_space.high == np.inf
)
if not isinstance(min_action, np.ndarray):
assert np.issubdtype(type(min_action), np.integer) or np.issubdtype(
type(min_action), np.floating
)
min_action = np.full(env.action_space.shape, min_action)
assert min_action.shape == env.action_space.shape
assert not np.any(min_action == np.inf)
if not isinstance(max_action, np.ndarray):
assert np.issubdtype(type(max_action), np.integer) or np.issubdtype(
type(max_action), np.floating
)
max_action = np.full(env.action_space.shape, max_action)
assert max_action.shape == env.action_space.shape
assert not np.any(max_action == np.inf)
assert isinstance(env.action_space, Box)
assert np.all(np.less_equal(min_action, max_action))
# Imagine the x-axis between the old Box and the y-axis being the new Box
gradient = (env.action_space.high - env.action_space.low) / (
max_action - min_action
)
intercept = gradient * -min_action + env.action_space.low
act_space, _, func = rescale_box(env.action_space, min_action, max_action)
TransformAction.__init__(
self,
env=env,
func=lambda action: gradient * action + intercept,
action_space=Box(
low=min_action,
high=max_action,
shape=env.action_space.shape,
dtype=env.action_space.dtype,
),
func=func,
action_space=act_space,
)

View File

@@ -35,6 +35,8 @@ __all__ = [
"AddRenderObservation",
]
from gymnasium.wrappers.utils import rescale_box
class TransformObservation(
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
@@ -107,7 +109,7 @@ class FilterObservation(
>>> env = gym.make("CartPole-v1")
>>> env = gym.wrappers.TimeAwareObservation(env, flatten=False)
>>> env.observation_space
Dict('obs': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), 'time': Box(0, 500, (1,), int32))
Dict('obs': Box([-4.8 -inf -0.41887903 -inf], [4.8 inf 0.41887903 inf], (4,), float32), 'time': Box(0, 500, (1,), int32))
>>> env.reset(seed=42)
({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': array([0], dtype=int32)}, {})
>>> env = FilterObservation(env, filter_keys=['time'])
@@ -462,6 +464,8 @@ class RescaleObservation(
):
"""Affinely (linearly) rescales a ``Box`` observation space of the environment to within the range of ``[min_obs, max_obs]``.
For unbounded components in the original observation space, the corresponding target bounds must also be infinite and vice versa.
A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.RescaleObservation`.
Example:
@@ -492,57 +496,15 @@ class RescaleObservation(
max_obs: The new maximum observation bound
"""
assert isinstance(env.observation_space, spaces.Box)
assert not np.any(env.observation_space.low == np.inf) and not np.any(
env.observation_space.high == np.inf
)
if not isinstance(min_obs, np.ndarray):
assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
type(max_obs), np.floating
)
min_obs = np.full(env.observation_space.shape, min_obs)
assert (
min_obs.shape == env.observation_space.shape
), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}"
assert not np.any(min_obs == np.inf)
if not isinstance(max_obs, np.ndarray):
assert np.issubdtype(type(max_obs), np.integer) or np.issubdtype(
type(max_obs), np.floating
)
max_obs = np.full(env.observation_space.shape, max_obs)
assert max_obs.shape == env.observation_space.shape
assert not np.any(max_obs == np.inf)
self.min_obs = min_obs
self.max_obs = max_obs
# Imagine the x-axis between the old Box and the y-axis being the new Box
# float128 is not available everywhere
try:
high_low_diff_dtype = np.float128
except AttributeError:
high_low_diff_dtype = np.float64
high_low_diff = np.array(
env.observation_space.high, dtype=high_low_diff_dtype
) - np.array(env.observation_space.low, dtype=high_low_diff_dtype)
gradient = np.array(
(max_obs - min_obs) / high_low_diff, dtype=env.observation_space.dtype
)
intercept = gradient * -env.observation_space.low + min_obs
gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs)
obs_space, func, _ = rescale_box(env.observation_space, min_obs, max_obs)
TransformObservation.__init__(
self,
env=env,
func=lambda obs: gradient * obs + intercept,
observation_space=spaces.Box(
low=min_obs,
high=max_obs,
shape=env.observation_space.shape,
dtype=env.observation_space.dtype,
),
func=func,
observation_space=obs_space,
)
@@ -642,7 +604,7 @@ class AddRenderObservation(
>>> env = gym.make("CartPole-v1", render_mode="rgb_array")
>>> env = AddRenderObservation(env, render_only=False)
>>> env.observation_space
Dict('pixels': Box(0, 255, (400, 600, 3), uint8), 'state': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32))
Dict('pixels': Box(0, 255, (400, 600, 3), uint8), 'state': Box([-4.8 -inf -0.41887903 -inf], [4.8 inf 0.41887903 inf], (4,), float32))
>>> obs, info = env.reset(seed=123)
>>> obs.keys()
dict_keys(['state', 'pixels'])

View File

@@ -1,6 +1,9 @@
"""Utility functions for the wrappers."""
from __future__ import annotations
from functools import singledispatch
from typing import Callable
import numpy as np
@@ -149,3 +152,86 @@ def _create_graph_zero_array(space: Graph):
@create_zero_array.register(OneOf)
def _create_one_of_zero_array(space: OneOf):
return 0, create_zero_array(space.spaces[0])
def rescale_box(
box: Box,
new_min: np.floating | np.integer | np.ndarray,
new_max: np.floating | np.integer | np.ndarray,
) -> tuple[Box, Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]:
"""Rescale and shift the given box space to match the given bounds.
For unbounded components in the original space, the corresponding target bounds must also be infinite and vice versa.
Args:
box: The box space to rescale
new_min: The new minimum bound
new_max: The new maximum bound
Returns:
A tuple containing the rescaled box space, the forward transformation function (original -> rescaled) and the
backward transformation function (rescaled -> original).
"""
assert isinstance(box, Box)
if not isinstance(new_min, np.ndarray):
assert np.issubdtype(type(new_min), np.integer) or np.issubdtype(
type(new_min), np.floating
)
new_min = np.full(box.shape, new_min)
assert (
new_min.shape == box.shape
), f"{new_min.shape}, {box.shape}, {new_min}, {box.low}"
if not isinstance(new_max, np.ndarray):
assert np.issubdtype(type(new_max), np.integer) or np.issubdtype(
type(new_max), np.floating
)
new_max = np.full(box.shape, new_max)
assert new_max.shape == box.shape
assert np.all((new_min == box.low)[np.isinf(new_min) | np.isinf(box.low)])
assert np.all((new_max == box.high)[np.isinf(new_max) | np.isinf(box.high)])
assert np.all(new_min <= new_max)
assert np.all(box.low <= box.high)
# Imagine the x-axis between the old Box and the y-axis being the new Box
# float128 is not available everywhere
try:
high_low_diff_dtype = np.float128
except AttributeError:
high_low_diff_dtype = np.float64
min_finite = np.isfinite(new_min)
max_finite = np.isfinite(new_max)
both_finite = min_finite & max_finite
high_low_diff = np.array(
box.high[both_finite], dtype=high_low_diff_dtype
) - np.array(box.low[both_finite], dtype=high_low_diff_dtype)
gradient = np.ones_like(new_min, dtype=box.dtype)
gradient[both_finite] = (
new_max[both_finite] - new_min[both_finite]
) / high_low_diff
intercept = np.zeros_like(new_min, dtype=box.dtype)
# In cases where both are finite, the lower operation takes precedence
intercept[max_finite] = new_max[max_finite] - box.high[max_finite]
intercept[min_finite] = (
gradient[min_finite] * -box.low[min_finite] + new_min[min_finite]
)
new_box = Box(
low=new_min,
high=new_max,
shape=box.shape,
dtype=box.dtype,
)
def forward(obs: np.ndarray) -> np.ndarray:
return gradient * obs + intercept
def backward(obs: np.ndarray) -> np.ndarray:
return (obs - intercept) / gradient
return new_box, forward, backward

View File

@@ -318,18 +318,18 @@ class RescaleObservation(VectorizeTransformObservation):
Example:
>>> import gymnasium as gym
>>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
>>> envs = gym.make_vec("MountainCar-v0", num_envs=3, vectorization_mode="sync")
>>> obs, info = envs.reset(seed=123)
>>> obs.min()
np.float32(-0.0446179)
np.float32(-0.46352962)
>>> obs.max()
np.float32(0.0469136)
np.float32(0.0)
>>> envs = RescaleObservation(envs, min_obs=-5.0, max_obs=5.0)
>>> obs, info = envs.reset(seed=123)
>>> obs.min()
np.float32(-0.33379582)
np.float32(-0.90849805)
>>> obs.max()
np.float32(0.55998987)
np.float32(0.0)
>>> envs.close()
"""

View File

@@ -22,6 +22,16 @@ from gymnasium.utils.env_checker import (
from tests.testing_env import GenericTestEnv
CHECK_ENV_IGNORE_WARNINGS = [
f"\x1b[33mWARN: {message}\x1b[0m"
for message in [
"A Box observation space minimum value is -infinity. This is probably too low.",
"A Box observation space maximum value is infinity. This is probably too high.",
"For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
]
]
@pytest.mark.parametrize(
"env",
[
@@ -51,6 +61,11 @@ def test_no_error_warnings(env):
"""A full version of this test with all gymnasium envs is run in tests/envs/test_envs.py."""
with warnings.catch_warnings(record=True) as caught_warnings:
check_env(env)
caught_warnings = [
warning
for warning in caught_warnings
if str(warning.message) not in CHECK_ENV_IGNORE_WARNINGS
]
assert len(caught_warnings) == 0, [warning.message for warning in caught_warnings]

View File

@@ -12,28 +12,37 @@ def test_rescale_action_wrapper():
"""Test that the action is rescale within a min / max bound."""
env = GenericTestEnv(
step_func=record_action_step,
action_space=Box(np.array([0, 1]), np.array([1, 3])),
action_space=Box(
np.array([0, 1, -np.inf, 5, -np.inf], dtype=np.float32),
np.array([1, 3, np.inf, np.inf, 7], dtype=np.float32),
),
)
wrapped_env = RescaleAction(
env, min_action=np.array([-5, 0]), max_action=np.array([5, 1])
env,
min_action=np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32),
max_action=np.array([5, 1.0, np.inf, np.inf, 4], dtype=np.float32),
)
assert wrapped_env.action_space == Box(
np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32),
np.array([5, 1, np.inf, np.inf, 4], dtype=np.float32),
)
assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1]))
for sample_action, expected_action in (
(
np.array([0.0, 0.5], dtype=np.float32),
np.array([0.5, 2.0], dtype=np.float32),
np.array([0.0, 0.5, 7.0, -1.0, -23.0], dtype=np.float32),
np.array([0.5, 2.0, 7.0, 5.0, -20.0], dtype=np.float32),
),
(
np.array([-5.0, 0.0], dtype=np.float32),
np.array([0.0, 1.0], dtype=np.float32),
np.array([-5.0, 0.0, -4.0, 0.0, -3.0], dtype=np.float32),
np.array([0.0, 1.0, -4.0, 6.0, 0.0], dtype=np.float32),
),
(
np.array([5.0, 1.0], dtype=np.float32),
np.array([1.0, 3.0], dtype=np.float32),
np.array([5.0, 1.0, 0.0, 1.0, 4.0], dtype=np.float32),
np.array([1.0, 3.0, 0.0, 7.0, 7.0], dtype=np.float32),
),
):
assert sample_action in wrapped_env.action_space
assert expected_action in env.action_space
_, _, _, _, info = wrapped_env.step(sample_action)
assert np.all(info["action"] == expected_action)

View File

@@ -12,32 +12,34 @@ def test_rescale_observation():
"""Test the ``RescaleObservation`` wrapper."""
env = GenericTestEnv(
observation_space=Box(
np.array([0, 1], dtype=np.float32), np.array([1, 3], dtype=np.float32)
np.array([0, 1, -np.inf, 5, -np.inf], dtype=np.float32),
np.array([1, 3, np.inf, np.inf, 7], dtype=np.float32),
),
reset_func=record_obs_reset,
step_func=record_action_as_obs_step,
)
wrapped_env = RescaleObservation(
env,
min_obs=np.array([-5, 0], dtype=np.float32),
max_obs=np.array([5, 1], dtype=np.float32),
min_obs=np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32),
max_obs=np.array([5, 1.0, np.inf, np.inf, 4], dtype=np.float32),
)
assert wrapped_env.observation_space == Box(
np.array([-5, 0], dtype=np.float32), np.array([5, 1], dtype=np.float32)
np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32),
np.array([5, 1, np.inf, np.inf, 4], dtype=np.float32),
)
for sample_obs, expected_obs in (
(
np.array([0.5, 2.0], dtype=np.float32),
np.array([0.0, 0.5], dtype=np.float32),
np.array([0.5, 2.0, 7.0, 5.0, -20.0], dtype=np.float32),
np.array([0.0, 0.5, 7.0, -1.0, -23.0], dtype=np.float32),
),
(
np.array([0.0, 1.0], dtype=np.float32),
np.array([-5.0, 0.0], dtype=np.float32),
np.array([0.0, 1.0, -4.0, 6.0, 0.0], dtype=np.float32),
np.array([-5.0, 0.0, -4.0, 0.0, -3.0], dtype=np.float32),
),
(
np.array([1.0, 3.0], dtype=np.float32),
np.array([5.0, 1.0], dtype=np.float32),
np.array([1.0, 3.0, 0.0, 7.0, 7.0], dtype=np.float32),
np.array([5.0, 1.0, 0.0, 1.0, 4.0], dtype=np.float32),
),
):
assert sample_obs in env.observation_space

View File

@@ -45,7 +45,14 @@ def custom_environments():
("CarRacing-v2", "GrayscaleObservation", {}),
("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}),
("CarRacing-v2", "ReshapeObservation", {"shape": (96, 48, 6)}),
("CartPole-v1", "RescaleObservation", {"min_obs": 0, "max_obs": 1}),
(
"CartPole-v1",
"RescaleObservation",
{
"min_obs": np.array([0, -np.inf, 0, -np.inf]),
"max_obs": np.array([1, np.inf, 1, np.inf]),
},
),
("CarRacing-v2", "DtypeObservation", {"dtype": np.int32}),
# ("CartPole-v1", "RenderObservation", {}), # not implemented
# ("CartPole-v1", "TimeAwareObservation", {}), # not implemented