mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-02 06:16:32 +00:00
Update RescaleAction
and RescaleObservation
for np.inf
bounds (#1095)
This commit is contained in:
@@ -27,7 +27,7 @@ title: Env
|
||||
>>> env.action_space
|
||||
Discrete(2)
|
||||
>>> env.observation_space
|
||||
Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)
|
||||
Box(-inf, inf, (4,), float32)
|
||||
|
||||
.. autoattribute:: gymnasium.Env.observation_space
|
||||
|
||||
@@ -36,9 +36,9 @@ title: Env
|
||||
.. code::
|
||||
|
||||
>>> env.observation_space.high
|
||||
array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38], dtype=float32)
|
||||
array([4.8000002e+00, inf, 4.1887903e-01, inf], dtype=float32)
|
||||
>>> env.observation_space.low
|
||||
array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38], dtype=float32)
|
||||
array([-4.8000002e+00, -inf, -4.1887903e-01, -inf], dtype=float32)
|
||||
|
||||
.. autoattribute:: gymnasium.Env.metadata
|
||||
|
||||
|
@@ -140,9 +140,9 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
||||
high = np.array(
|
||||
[
|
||||
self.x_threshold * 2,
|
||||
np.finfo(np.float32).max,
|
||||
np.inf,
|
||||
self.theta_threshold_radians * 2,
|
||||
np.finfo(np.float32).max,
|
||||
np.inf,
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
@@ -401,9 +401,9 @@ class CartPoleVectorEnv(VectorEnv):
|
||||
high = np.array(
|
||||
[
|
||||
self.x_threshold * 2,
|
||||
np.finfo(np.float32).max,
|
||||
np.inf,
|
||||
self.theta_threshold_radians * 2,
|
||||
np.finfo(np.float32).max,
|
||||
np.inf,
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
@@ -59,16 +59,13 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
|
||||
>>> envs.action_space
|
||||
MultiDiscrete([2 2 2])
|
||||
>>> envs.observation_space
|
||||
Box([[-4.80000019e+00 -3.40282347e+38 -4.18879032e-01 -3.40282347e+38
|
||||
0.00000000e+00]
|
||||
[-4.80000019e+00 -3.40282347e+38 -4.18879032e-01 -3.40282347e+38
|
||||
0.00000000e+00]
|
||||
[-4.80000019e+00 -3.40282347e+38 -4.18879032e-01 -3.40282347e+38
|
||||
0.00000000e+00]], [[4.80000019e+00 3.40282347e+38 4.18879032e-01 3.40282347e+38
|
||||
Box([[-4.80000019 -inf -0.41887903 -inf 0. ]
|
||||
[-4.80000019 -inf -0.41887903 -inf 0. ]
|
||||
[-4.80000019 -inf -0.41887903 -inf 0. ]], [[4.80000019e+00 inf 4.18879032e-01 inf
|
||||
5.00000000e+02]
|
||||
[4.80000019e+00 3.40282347e+38 4.18879032e-01 3.40282347e+38
|
||||
[4.80000019e+00 inf 4.18879032e-01 inf
|
||||
5.00000000e+02]
|
||||
[4.80000019e+00 3.40282347e+38 4.18879032e-01 3.40282347e+38
|
||||
[4.80000019e+00 inf 4.18879032e-01 inf
|
||||
5.00000000e+02]], (3, 5), float64)
|
||||
>>> observations, infos = envs.reset(seed=123)
|
||||
>>> observations
|
||||
|
@@ -129,8 +129,7 @@ class TimeAwareObservation(
|
||||
>>> env = gym.make("CartPole-v1")
|
||||
>>> env = TimeAwareObservation(env)
|
||||
>>> env.observation_space
|
||||
Box([-4.80000019e+00 -3.40282347e+38 -4.18879032e-01 -3.40282347e+38
|
||||
0.00000000e+00], [4.80000019e+00 3.40282347e+38 4.18879032e-01 3.40282347e+38
|
||||
Box([-4.80000019 -inf -0.41887903 -inf 0. ], [4.80000019e+00 inf 4.18879032e-01 inf
|
||||
5.00000000e+02], (5,), float64)
|
||||
>>> env.reset(seed=42)[0]
|
||||
array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 , 0. ])
|
||||
@@ -142,8 +141,7 @@ class TimeAwareObservation(
|
||||
>>> env = gym.make('CartPole-v1')
|
||||
>>> env = TimeAwareObservation(env, normalize_time=True)
|
||||
>>> env.observation_space
|
||||
Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38
|
||||
0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 1.0000000e+00], (5,), float32)
|
||||
Box([-4.8 -inf -0.41887903 -inf 0. ], [4.8 inf 0.41887903 inf 1. ], (5,), float32)
|
||||
>>> env.reset(seed=42)[0]
|
||||
array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 , 0. ],
|
||||
dtype=float32)
|
||||
@@ -156,7 +154,7 @@ class TimeAwareObservation(
|
||||
>>> env = gym.make("CartPole-v1")
|
||||
>>> env = TimeAwareObservation(env, flatten=False)
|
||||
>>> env.observation_space
|
||||
Dict('obs': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), 'time': Box(0, 500, (1,), int32))
|
||||
Dict('obs': Box([-4.8 -inf -0.41887903 -inf], [4.8 inf 0.41887903 inf], (4,), float32), 'time': Box(0, 500, (1,), int32))
|
||||
>>> env.reset(seed=42)[0]
|
||||
{'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': array([0], dtype=int32)}
|
||||
>>> _ = env.action_space.seed(42)
|
||||
|
@@ -18,6 +18,8 @@ from gymnasium.spaces import Box, Space
|
||||
|
||||
__all__ = ["TransformAction", "ClipAction", "RescaleAction"]
|
||||
|
||||
from gymnasium.wrappers.utils import rescale_box
|
||||
|
||||
|
||||
class TransformAction(
|
||||
gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs
|
||||
@@ -153,8 +155,8 @@ class RescaleAction(
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env[ObsType, ActType],
|
||||
min_action: float | int | np.ndarray,
|
||||
max_action: float | int | np.ndarray,
|
||||
min_action: np.floating | np.integer | np.ndarray,
|
||||
max_action: np.floating | np.integer | np.ndarray,
|
||||
):
|
||||
"""Constructor for the Rescale Action wrapper.
|
||||
|
||||
@@ -163,49 +165,16 @@ class RescaleAction(
|
||||
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
|
||||
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
|
||||
"""
|
||||
assert isinstance(env.action_space, Box)
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self, min_action=min_action, max_action=max_action
|
||||
)
|
||||
|
||||
assert isinstance(env.action_space, Box)
|
||||
assert not np.any(env.action_space.low == np.inf) and not np.any(
|
||||
env.action_space.high == np.inf
|
||||
)
|
||||
|
||||
if not isinstance(min_action, np.ndarray):
|
||||
assert np.issubdtype(type(min_action), np.integer) or np.issubdtype(
|
||||
type(min_action), np.floating
|
||||
)
|
||||
min_action = np.full(env.action_space.shape, min_action)
|
||||
|
||||
assert min_action.shape == env.action_space.shape
|
||||
assert not np.any(min_action == np.inf)
|
||||
|
||||
if not isinstance(max_action, np.ndarray):
|
||||
assert np.issubdtype(type(max_action), np.integer) or np.issubdtype(
|
||||
type(max_action), np.floating
|
||||
)
|
||||
max_action = np.full(env.action_space.shape, max_action)
|
||||
assert max_action.shape == env.action_space.shape
|
||||
assert not np.any(max_action == np.inf)
|
||||
|
||||
assert isinstance(env.action_space, Box)
|
||||
assert np.all(np.less_equal(min_action, max_action))
|
||||
|
||||
# Imagine the x-axis between the old Box and the y-axis being the new Box
|
||||
gradient = (env.action_space.high - env.action_space.low) / (
|
||||
max_action - min_action
|
||||
)
|
||||
intercept = gradient * -min_action + env.action_space.low
|
||||
|
||||
act_space, _, func = rescale_box(env.action_space, min_action, max_action)
|
||||
TransformAction.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda action: gradient * action + intercept,
|
||||
action_space=Box(
|
||||
low=min_action,
|
||||
high=max_action,
|
||||
shape=env.action_space.shape,
|
||||
dtype=env.action_space.dtype,
|
||||
),
|
||||
func=func,
|
||||
action_space=act_space,
|
||||
)
|
||||
|
@@ -35,6 +35,8 @@ __all__ = [
|
||||
"AddRenderObservation",
|
||||
]
|
||||
|
||||
from gymnasium.wrappers.utils import rescale_box
|
||||
|
||||
|
||||
class TransformObservation(
|
||||
gym.ObservationWrapper[WrapperObsType, ActType, ObsType],
|
||||
@@ -107,7 +109,7 @@ class FilterObservation(
|
||||
>>> env = gym.make("CartPole-v1")
|
||||
>>> env = gym.wrappers.TimeAwareObservation(env, flatten=False)
|
||||
>>> env.observation_space
|
||||
Dict('obs': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), 'time': Box(0, 500, (1,), int32))
|
||||
Dict('obs': Box([-4.8 -inf -0.41887903 -inf], [4.8 inf 0.41887903 inf], (4,), float32), 'time': Box(0, 500, (1,), int32))
|
||||
>>> env.reset(seed=42)
|
||||
({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': array([0], dtype=int32)}, {})
|
||||
>>> env = FilterObservation(env, filter_keys=['time'])
|
||||
@@ -462,6 +464,8 @@ class RescaleObservation(
|
||||
):
|
||||
"""Affinely (linearly) rescales a ``Box`` observation space of the environment to within the range of ``[min_obs, max_obs]``.
|
||||
|
||||
For unbounded components in the original observation space, the corresponding target bounds must also be infinite and vice versa.
|
||||
|
||||
A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.RescaleObservation`.
|
||||
|
||||
Example:
|
||||
@@ -492,57 +496,15 @@ class RescaleObservation(
|
||||
max_obs: The new maximum observation bound
|
||||
"""
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert not np.any(env.observation_space.low == np.inf) and not np.any(
|
||||
env.observation_space.high == np.inf
|
||||
)
|
||||
|
||||
if not isinstance(min_obs, np.ndarray):
|
||||
assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
|
||||
type(max_obs), np.floating
|
||||
)
|
||||
min_obs = np.full(env.observation_space.shape, min_obs)
|
||||
assert (
|
||||
min_obs.shape == env.observation_space.shape
|
||||
), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}"
|
||||
assert not np.any(min_obs == np.inf)
|
||||
|
||||
if not isinstance(max_obs, np.ndarray):
|
||||
assert np.issubdtype(type(max_obs), np.integer) or np.issubdtype(
|
||||
type(max_obs), np.floating
|
||||
)
|
||||
max_obs = np.full(env.observation_space.shape, max_obs)
|
||||
assert max_obs.shape == env.observation_space.shape
|
||||
assert not np.any(max_obs == np.inf)
|
||||
|
||||
self.min_obs = min_obs
|
||||
self.max_obs = max_obs
|
||||
|
||||
# Imagine the x-axis between the old Box and the y-axis being the new Box
|
||||
# float128 is not available everywhere
|
||||
try:
|
||||
high_low_diff_dtype = np.float128
|
||||
except AttributeError:
|
||||
high_low_diff_dtype = np.float64
|
||||
high_low_diff = np.array(
|
||||
env.observation_space.high, dtype=high_low_diff_dtype
|
||||
) - np.array(env.observation_space.low, dtype=high_low_diff_dtype)
|
||||
gradient = np.array(
|
||||
(max_obs - min_obs) / high_low_diff, dtype=env.observation_space.dtype
|
||||
)
|
||||
|
||||
intercept = gradient * -env.observation_space.low + min_obs
|
||||
|
||||
gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs)
|
||||
|
||||
obs_space, func, _ = rescale_box(env.observation_space, min_obs, max_obs)
|
||||
TransformObservation.__init__(
|
||||
self,
|
||||
env=env,
|
||||
func=lambda obs: gradient * obs + intercept,
|
||||
observation_space=spaces.Box(
|
||||
low=min_obs,
|
||||
high=max_obs,
|
||||
shape=env.observation_space.shape,
|
||||
dtype=env.observation_space.dtype,
|
||||
),
|
||||
func=func,
|
||||
observation_space=obs_space,
|
||||
)
|
||||
|
||||
|
||||
@@ -642,7 +604,7 @@ class AddRenderObservation(
|
||||
>>> env = gym.make("CartPole-v1", render_mode="rgb_array")
|
||||
>>> env = AddRenderObservation(env, render_only=False)
|
||||
>>> env.observation_space
|
||||
Dict('pixels': Box(0, 255, (400, 600, 3), uint8), 'state': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32))
|
||||
Dict('pixels': Box(0, 255, (400, 600, 3), uint8), 'state': Box([-4.8 -inf -0.41887903 -inf], [4.8 inf 0.41887903 inf], (4,), float32))
|
||||
>>> obs, info = env.reset(seed=123)
|
||||
>>> obs.keys()
|
||||
dict_keys(['state', 'pixels'])
|
||||
|
@@ -1,6 +1,9 @@
|
||||
"""Utility functions for the wrappers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import singledispatch
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -149,3 +152,86 @@ def _create_graph_zero_array(space: Graph):
|
||||
@create_zero_array.register(OneOf)
|
||||
def _create_one_of_zero_array(space: OneOf):
|
||||
return 0, create_zero_array(space.spaces[0])
|
||||
|
||||
|
||||
def rescale_box(
|
||||
box: Box,
|
||||
new_min: np.floating | np.integer | np.ndarray,
|
||||
new_max: np.floating | np.integer | np.ndarray,
|
||||
) -> tuple[Box, Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]:
|
||||
"""Rescale and shift the given box space to match the given bounds.
|
||||
|
||||
For unbounded components in the original space, the corresponding target bounds must also be infinite and vice versa.
|
||||
|
||||
Args:
|
||||
box: The box space to rescale
|
||||
new_min: The new minimum bound
|
||||
new_max: The new maximum bound
|
||||
|
||||
Returns:
|
||||
A tuple containing the rescaled box space, the forward transformation function (original -> rescaled) and the
|
||||
backward transformation function (rescaled -> original).
|
||||
"""
|
||||
assert isinstance(box, Box)
|
||||
|
||||
if not isinstance(new_min, np.ndarray):
|
||||
assert np.issubdtype(type(new_min), np.integer) or np.issubdtype(
|
||||
type(new_min), np.floating
|
||||
)
|
||||
new_min = np.full(box.shape, new_min)
|
||||
assert (
|
||||
new_min.shape == box.shape
|
||||
), f"{new_min.shape}, {box.shape}, {new_min}, {box.low}"
|
||||
|
||||
if not isinstance(new_max, np.ndarray):
|
||||
assert np.issubdtype(type(new_max), np.integer) or np.issubdtype(
|
||||
type(new_max), np.floating
|
||||
)
|
||||
new_max = np.full(box.shape, new_max)
|
||||
assert new_max.shape == box.shape
|
||||
assert np.all((new_min == box.low)[np.isinf(new_min) | np.isinf(box.low)])
|
||||
assert np.all((new_max == box.high)[np.isinf(new_max) | np.isinf(box.high)])
|
||||
assert np.all(new_min <= new_max)
|
||||
assert np.all(box.low <= box.high)
|
||||
|
||||
# Imagine the x-axis between the old Box and the y-axis being the new Box
|
||||
# float128 is not available everywhere
|
||||
try:
|
||||
high_low_diff_dtype = np.float128
|
||||
except AttributeError:
|
||||
high_low_diff_dtype = np.float64
|
||||
|
||||
min_finite = np.isfinite(new_min)
|
||||
max_finite = np.isfinite(new_max)
|
||||
both_finite = min_finite & max_finite
|
||||
|
||||
high_low_diff = np.array(
|
||||
box.high[both_finite], dtype=high_low_diff_dtype
|
||||
) - np.array(box.low[both_finite], dtype=high_low_diff_dtype)
|
||||
|
||||
gradient = np.ones_like(new_min, dtype=box.dtype)
|
||||
gradient[both_finite] = (
|
||||
new_max[both_finite] - new_min[both_finite]
|
||||
) / high_low_diff
|
||||
|
||||
intercept = np.zeros_like(new_min, dtype=box.dtype)
|
||||
# In cases where both are finite, the lower operation takes precedence
|
||||
intercept[max_finite] = new_max[max_finite] - box.high[max_finite]
|
||||
intercept[min_finite] = (
|
||||
gradient[min_finite] * -box.low[min_finite] + new_min[min_finite]
|
||||
)
|
||||
|
||||
new_box = Box(
|
||||
low=new_min,
|
||||
high=new_max,
|
||||
shape=box.shape,
|
||||
dtype=box.dtype,
|
||||
)
|
||||
|
||||
def forward(obs: np.ndarray) -> np.ndarray:
|
||||
return gradient * obs + intercept
|
||||
|
||||
def backward(obs: np.ndarray) -> np.ndarray:
|
||||
return (obs - intercept) / gradient
|
||||
|
||||
return new_box, forward, backward
|
||||
|
@@ -318,18 +318,18 @@ class RescaleObservation(VectorizeTransformObservation):
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
|
||||
>>> envs = gym.make_vec("MountainCar-v0", num_envs=3, vectorization_mode="sync")
|
||||
>>> obs, info = envs.reset(seed=123)
|
||||
>>> obs.min()
|
||||
np.float32(-0.0446179)
|
||||
np.float32(-0.46352962)
|
||||
>>> obs.max()
|
||||
np.float32(0.0469136)
|
||||
np.float32(0.0)
|
||||
>>> envs = RescaleObservation(envs, min_obs=-5.0, max_obs=5.0)
|
||||
>>> obs, info = envs.reset(seed=123)
|
||||
>>> obs.min()
|
||||
np.float32(-0.33379582)
|
||||
np.float32(-0.90849805)
|
||||
>>> obs.max()
|
||||
np.float32(0.55998987)
|
||||
np.float32(0.0)
|
||||
>>> envs.close()
|
||||
"""
|
||||
|
||||
|
@@ -22,6 +22,16 @@ from gymnasium.utils.env_checker import (
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
|
||||
CHECK_ENV_IGNORE_WARNINGS = [
|
||||
f"\x1b[33mWARN: {message}\x1b[0m"
|
||||
for message in [
|
||||
"A Box observation space minimum value is -infinity. This is probably too low.",
|
||||
"A Box observation space maximum value is infinity. This is probably too high.",
|
||||
"For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env",
|
||||
[
|
||||
@@ -51,6 +61,11 @@ def test_no_error_warnings(env):
|
||||
"""A full version of this test with all gymnasium envs is run in tests/envs/test_envs.py."""
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
check_env(env)
|
||||
caught_warnings = [
|
||||
warning
|
||||
for warning in caught_warnings
|
||||
if str(warning.message) not in CHECK_ENV_IGNORE_WARNINGS
|
||||
]
|
||||
|
||||
assert len(caught_warnings) == 0, [warning.message for warning in caught_warnings]
|
||||
|
||||
|
@@ -12,28 +12,37 @@ def test_rescale_action_wrapper():
|
||||
"""Test that the action is rescale within a min / max bound."""
|
||||
env = GenericTestEnv(
|
||||
step_func=record_action_step,
|
||||
action_space=Box(np.array([0, 1]), np.array([1, 3])),
|
||||
action_space=Box(
|
||||
np.array([0, 1, -np.inf, 5, -np.inf], dtype=np.float32),
|
||||
np.array([1, 3, np.inf, np.inf, 7], dtype=np.float32),
|
||||
),
|
||||
)
|
||||
wrapped_env = RescaleAction(
|
||||
env, min_action=np.array([-5, 0]), max_action=np.array([5, 1])
|
||||
env,
|
||||
min_action=np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32),
|
||||
max_action=np.array([5, 1.0, np.inf, np.inf, 4], dtype=np.float32),
|
||||
)
|
||||
assert wrapped_env.action_space == Box(
|
||||
np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32),
|
||||
np.array([5, 1, np.inf, np.inf, 4], dtype=np.float32),
|
||||
)
|
||||
assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1]))
|
||||
|
||||
for sample_action, expected_action in (
|
||||
(
|
||||
np.array([0.0, 0.5], dtype=np.float32),
|
||||
np.array([0.5, 2.0], dtype=np.float32),
|
||||
np.array([0.0, 0.5, 7.0, -1.0, -23.0], dtype=np.float32),
|
||||
np.array([0.5, 2.0, 7.0, 5.0, -20.0], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
np.array([-5.0, 0.0], dtype=np.float32),
|
||||
np.array([0.0, 1.0], dtype=np.float32),
|
||||
np.array([-5.0, 0.0, -4.0, 0.0, -3.0], dtype=np.float32),
|
||||
np.array([0.0, 1.0, -4.0, 6.0, 0.0], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
np.array([5.0, 1.0], dtype=np.float32),
|
||||
np.array([1.0, 3.0], dtype=np.float32),
|
||||
np.array([5.0, 1.0, 0.0, 1.0, 4.0], dtype=np.float32),
|
||||
np.array([1.0, 3.0, 0.0, 7.0, 7.0], dtype=np.float32),
|
||||
),
|
||||
):
|
||||
assert sample_action in wrapped_env.action_space
|
||||
assert expected_action in env.action_space
|
||||
|
||||
_, _, _, _, info = wrapped_env.step(sample_action)
|
||||
assert np.all(info["action"] == expected_action)
|
||||
|
@@ -12,32 +12,34 @@ def test_rescale_observation():
|
||||
"""Test the ``RescaleObservation`` wrapper."""
|
||||
env = GenericTestEnv(
|
||||
observation_space=Box(
|
||||
np.array([0, 1], dtype=np.float32), np.array([1, 3], dtype=np.float32)
|
||||
np.array([0, 1, -np.inf, 5, -np.inf], dtype=np.float32),
|
||||
np.array([1, 3, np.inf, np.inf, 7], dtype=np.float32),
|
||||
),
|
||||
reset_func=record_obs_reset,
|
||||
step_func=record_action_as_obs_step,
|
||||
)
|
||||
wrapped_env = RescaleObservation(
|
||||
env,
|
||||
min_obs=np.array([-5, 0], dtype=np.float32),
|
||||
max_obs=np.array([5, 1], dtype=np.float32),
|
||||
min_obs=np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32),
|
||||
max_obs=np.array([5, 1.0, np.inf, np.inf, 4], dtype=np.float32),
|
||||
)
|
||||
assert wrapped_env.observation_space == Box(
|
||||
np.array([-5, 0], dtype=np.float32), np.array([5, 1], dtype=np.float32)
|
||||
np.array([-5, 0, -np.inf, -1, -np.inf], dtype=np.float32),
|
||||
np.array([5, 1, np.inf, np.inf, 4], dtype=np.float32),
|
||||
)
|
||||
|
||||
for sample_obs, expected_obs in (
|
||||
(
|
||||
np.array([0.5, 2.0], dtype=np.float32),
|
||||
np.array([0.0, 0.5], dtype=np.float32),
|
||||
np.array([0.5, 2.0, 7.0, 5.0, -20.0], dtype=np.float32),
|
||||
np.array([0.0, 0.5, 7.0, -1.0, -23.0], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
np.array([0.0, 1.0], dtype=np.float32),
|
||||
np.array([-5.0, 0.0], dtype=np.float32),
|
||||
np.array([0.0, 1.0, -4.0, 6.0, 0.0], dtype=np.float32),
|
||||
np.array([-5.0, 0.0, -4.0, 0.0, -3.0], dtype=np.float32),
|
||||
),
|
||||
(
|
||||
np.array([1.0, 3.0], dtype=np.float32),
|
||||
np.array([5.0, 1.0], dtype=np.float32),
|
||||
np.array([1.0, 3.0, 0.0, 7.0, 7.0], dtype=np.float32),
|
||||
np.array([5.0, 1.0, 0.0, 1.0, 4.0], dtype=np.float32),
|
||||
),
|
||||
):
|
||||
assert sample_obs in env.observation_space
|
||||
|
@@ -45,7 +45,14 @@ def custom_environments():
|
||||
("CarRacing-v2", "GrayscaleObservation", {}),
|
||||
("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}),
|
||||
("CarRacing-v2", "ReshapeObservation", {"shape": (96, 48, 6)}),
|
||||
("CartPole-v1", "RescaleObservation", {"min_obs": 0, "max_obs": 1}),
|
||||
(
|
||||
"CartPole-v1",
|
||||
"RescaleObservation",
|
||||
{
|
||||
"min_obs": np.array([0, -np.inf, 0, -np.inf]),
|
||||
"max_obs": np.array([1, np.inf, 1, np.inf]),
|
||||
},
|
||||
),
|
||||
("CarRacing-v2", "DtypeObservation", {"dtype": np.int32}),
|
||||
# ("CartPole-v1", "RenderObservation", {}), # not implemented
|
||||
# ("CartPole-v1", "TimeAwareObservation", {}), # not implemented
|
||||
|
Reference in New Issue
Block a user