diff --git a/gymnasium/experimental/wrappers/__init__.py b/gymnasium/experimental/wrappers/__init__.py index 556272fea..27e8a8156 100644 --- a/gymnasium/experimental/wrappers/__init__.py +++ b/gymnasium/experimental/wrappers/__init__.py @@ -5,7 +5,11 @@ from typing import TypeVar ArgType = TypeVar("ArgType") -from gymnasium.experimental.wrappers.lambda_action import LambdaActionV0 +from gymnasium.experimental.wrappers.lambda_action import ( + LambdaActionV0, + ClipActionV0, + RescaleActionV0, +) from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0 from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0 @@ -13,6 +17,8 @@ __all__ = [ "ArgType", # Lambda Action "LambdaActionV0", + "ClipActionV0", + "RescaleActionV0", # Lambda Observation "LambdaObservationV0", # Lambda Reward diff --git a/gymnasium/experimental/wrappers/lambda_action.py b/gymnasium/experimental/wrappers/lambda_action.py index b858b5970..9cb7c4a30 100644 --- a/gymnasium/experimental/wrappers/lambda_action.py +++ b/gymnasium/experimental/wrappers/lambda_action.py @@ -1,8 +1,11 @@ """Lambda action wrapper which apply a function to the provided action.""" +from typing import Any, Callable, Union -from typing import Any, Callable +import jumpy as jp +import numpy as np import gymnasium as gym +from gymnasium import spaces from gymnasium.core import ActType from gymnasium.experimental.wrappers import ArgType @@ -28,3 +31,93 @@ class LambdaActionV0(gym.ActionWrapper): def action(self, action: ActType) -> Any: """Apply function to action.""" return self.func(action) + + +class ClipActionV0(LambdaActionV0): + """Clip the continuous action within the valid :class:`Box` action space bound. 
+ + Example: + >>> import gymnasium as gym + >>> import numpy as np + >>> env = gym.make('BipedalWalker-v3', disable_env_checker=True) + >>> env = ClipActionV0(env) + >>> env.action_space + Box(-inf, inf, (4,), float32) + >>> env.step(np.array([5.0, 2.0, -10.0, 0.0])) + # Executes the action np.array([1.0, 1.0, -1.0, 0]) in the base environment + """ + + def __init__(self, env: gym.Env): + """A wrapper for clipping continuous actions within the valid bound. + + Args: + env: The environment to apply the wrapper + """ + assert isinstance(env.action_space, spaces.Box) + super().__init__( + env, + lambda action: jp.clip(action, env.action_space.low, env.action_space.high), + ) + + self.action_space = spaces.Box(-np.inf, np.inf, env.action_space.shape) + + +class RescaleActionV0(LambdaActionV0): + """Affinely rescales the continuous action space of the environment to the range [min_action, max_action]. + + The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action` + or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space. + + Example: + >>> import gymnasium as gym + >>> import numpy as np + >>> env = gym.make('BipedalWalker-v3', disable_env_checker=True) + >>> _ = env.reset(seed=42) + >>> obs, _, _, _, _ = env.step(np.array([1,1,1,1])) + >>> _ = env.reset(seed=42) + >>> min_action = -0.5 + >>> max_action = np.array([0.0, 0.5, 1.0, 0.75]) + >>> wrapped_env = RescaleActionV0(env, min_action=min_action, max_action=max_action) + >>> wrapped_env_obs, _, _, _, _ = wrapped_env.step(max_action) + >>> np.alltrue(obs == wrapped_env_obs) + True + """ + + def __init__( + self, + env: gym.Env, + min_action: Union[float, int, np.ndarray], + max_action: Union[float, int, np.ndarray], + ): + """Initializes the :class:`RescaleActionV0` wrapper. + + Args: + env (Env): The environment to apply the wrapper + min_action (float, int or np.ndarray): The min values for each action. 
This may be a numpy array or a scalar. + max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar. + """ + assert isinstance( + env.action_space, spaces.Box + ), f"expected Box action space, got {type(env.action_space)}" + assert np.less_equal(min_action, max_action).all(), (min_action, max_action) + + low = env.action_space.low + high = env.action_space.high + + self.min_action = np.full( + env.action_space.shape, min_action, dtype=env.action_space.dtype + ) + self.max_action = np.full( + env.action_space.shape, max_action, dtype=env.action_space.dtype + ) + + super().__init__( + env, + lambda action: jp.clip( + low + + (high - low) + * ((action - self.min_action) / (self.max_action - self.min_action)), + low, + high, + ), + ) diff --git a/gymnasium/wrappers/rescale_action.py b/gymnasium/wrappers/rescale_action.py index 9d8a89da0..c61c166a9 100644 --- a/gymnasium/wrappers/rescale_action.py +++ b/gymnasium/wrappers/rescale_action.py @@ -4,7 +4,7 @@ from typing import Union import numpy as np import gymnasium as gym -from gymnasium import spaces +from gymnasium.spaces import Box class RescaleAction(gym.ActionWrapper): @@ -41,7 +41,7 @@ class RescaleAction(gym.ActionWrapper): max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar. 
""" assert isinstance( - env.action_space, spaces.Box + env.action_space, Box ), f"expected Box action space, got {type(env.action_space)}" assert np.less_equal(min_action, max_action).all(), (min_action, max_action) @@ -52,12 +52,6 @@ class RescaleAction(gym.ActionWrapper): self.max_action = ( np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + max_action ) - self.action_space = spaces.Box( - low=min_action, - high=max_action, - shape=env.action_space.shape, - dtype=env.action_space.dtype, - ) def action(self, action): """Rescales the action affinely from [:attr:`min_action`, :attr:`max_action`] to the action space of the base environment, :attr:`env`. diff --git a/requirements.txt b/requirements.txt index 080d7c0fd..2dde29363 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,4 @@ pygame==2.1.3.dev8 ale-py~=0.8.0 mujoco==2.2 mujoco_py<2.2,>=2.1 -imageio>=2.14.1 +imageio>=2.14.1 \ No newline at end of file diff --git a/setup.py b/setup.py index 885e084d0..6d010c2bc 100644 --- a/setup.py +++ b/setup.py @@ -86,6 +86,7 @@ setup( include_package_data=True, install_requires=[ "numpy >= 1.21.0", + "jax-jumpy >= 0.2.0", "cloudpickle >= 1.2.0", "importlib_metadata >= 4.8.0; python_version < '3.10'", "gymnasium_notices >= 0.0.1", diff --git a/tests/experimental/wrappers/test_clip_action.py b/tests/experimental/wrappers/test_clip_action.py new file mode 100644 index 000000000..1ef9f4117 --- /dev/null +++ b/tests/experimental/wrappers/test_clip_action.py @@ -0,0 +1,47 @@ +"""Test suite for ClipActionV0.""" +import numpy as np +import pytest + +import gymnasium as gym +from gymnasium.experimental.wrappers import ClipActionV0 + +SEED = 42 + + +@pytest.mark.parametrize( + ("env", "action_unclipped_env", "action_clipped_env"), + ( + [ + # MountainCar action space: Box(-1.0, 1.0, (1,), float32) + gym.make("MountainCarContinuous-v0"), + np.array([1]), + np.array([1.5]), + ], + [ + # BipedalWalker action space: Box(-1.0, 1.0, (4,), float32) + 
gym.make("BipedalWalker-v3"), + np.array([1, 1, 1, 1]), + np.array([10, 10, 10, 10]), + ], + [ + # BipedalWalker action space: Box(-1.0, 1.0, (4,), float32) + gym.make("BipedalWalker-v3"), + np.array([0.5, 0.5, 1, 1]), + np.array([0.5, 0.5, 10, 10]), + ], + ), +) +def test_clip_actions_v0(env, action_unclipped_env, action_clipped_env): + """Tests if actions out of bound are correctly clipped. + + Tests whether out of bound actions for the wrapped + environments are correctly clipped. + """ + env.reset(seed=SEED) + obs, _, _, _, _ = env.step(action_unclipped_env) + + env.reset(seed=SEED) + wrapped_env = ClipActionV0(env) + wrapped_obs, _, _, _, _ = wrapped_env.step(action_clipped_env) + + assert np.alltrue(obs == wrapped_obs) diff --git a/tests/experimental/wrappers/test_lambda_action.py b/tests/experimental/wrappers/test_lambda_action.py index 2e3b71fcb..a359a34a1 100644 --- a/tests/experimental/wrappers/test_lambda_action.py +++ b/tests/experimental/wrappers/test_lambda_action.py @@ -12,7 +12,7 @@ NUM_ENVS = 3 BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64) -def env_step_fn(self, action): +def generic_step_fn(self, action): return 0, 0, False, False, {"action": action} @@ -20,7 +20,7 @@ def env_step_fn(self, action): ("env", "func", "action", "expected"), [ ( - GenericTestEnv(action_space=BOX_SPACE, step_fn=env_step_fn), + GenericTestEnv(action_space=BOX_SPACE, step_fn=generic_step_fn), lambda action: action + 2, 1, 3, diff --git a/tests/experimental/wrappers/test_rescale_action.py b/tests/experimental/wrappers/test_rescale_action.py new file mode 100644 index 000000000..4e9ba8997 --- /dev/null +++ b/tests/experimental/wrappers/test_rescale_action.py @@ -0,0 +1,51 @@ +"""Test suite for RescaleActionV0.""" +import jax +import numpy as np +import pytest + +import gymnasium as gym +from gymnasium.experimental.wrappers import RescaleActionV0 + +SEED = 42 + + +@pytest.mark.parametrize( + ("env", "low", "high", "action", "scaled_action"), + [ + ( + # BipedalWalker action 
space: Box(-1.0, 1.0, (4,), float32) + gym.make("BipedalWalker-v3"), + -0.5, + 0.5, + np.array([1, 1, 1, 1]), + np.array([0.5, 0.5, 0.5, 0.5]), + ), + ( + # BipedalWalker action space: Box(-1.0, 1.0, (4,), float32) + gym.make("BipedalWalker-v3"), + -0.5, + 0.5, + jax.numpy.array([1, 1, 1, 1]), + jax.numpy.array([0.5, 0.5, 0.5, 0.5]), + ), + ( + # BipedalWalker action space: Box(-1.0, 1.0, (4,), float32) + gym.make("BipedalWalker-v3"), + np.array([-0.5, -0.5, -1, -1], dtype=np.float32), + np.array([0.5, 0.5, 1, 1], dtype=np.float32), + jax.numpy.array([1, 1, 1, 1]), + jax.numpy.array([0.5, 0.5, 1, 1]), + ), + ], +) +def test_rescale_actions_v0_box(env, low, high, action, scaled_action): + """Test action rescaling.""" + env.reset(seed=SEED) + obs, _, _, _, _ = env.step(action) + + env.reset(seed=SEED) + wrapped_env = RescaleActionV0(env, low, high) + + obs_scaled, _, _, _, _ = wrapped_env.step(scaled_action) + + assert np.alltrue(obs == obs_scaled)