mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 05:44:31 +00:00
Versioned Action wrappers which supports jumpy (#150)
This commit is contained in:
@@ -5,7 +5,11 @@ from typing import TypeVar
|
||||
|
||||
ArgType = TypeVar("ArgType")
|
||||
|
||||
from gymnasium.experimental.wrappers.lambda_action import LambdaActionV0
|
||||
from gymnasium.experimental.wrappers.lambda_action import (
|
||||
LambdaActionV0,
|
||||
ClipActionV0,
|
||||
RescaleActionV0,
|
||||
)
|
||||
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
|
||||
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
|
||||
|
||||
@@ -13,6 +17,8 @@ __all__ = [
|
||||
"ArgType",
|
||||
# Lambda Action
|
||||
"LambdaActionV0",
|
||||
"ClipActionV0",
|
||||
"RescaleActionV0",
|
||||
# Lambda Observation
|
||||
"LambdaObservationV0",
|
||||
# Lambda Reward
|
||||
|
@@ -1,8 +1,11 @@
|
||||
"""Lambda action wrapper which apply a function to the provided action."""
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from typing import Any, Callable
|
||||
import jumpy as jp
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ActType
|
||||
from gymnasium.experimental.wrappers import ArgType
|
||||
|
||||
@@ -28,3 +31,93 @@ class LambdaActionV0(gym.ActionWrapper):
|
||||
def action(self, action: ActType) -> Any:
|
||||
"""Apply function to action."""
|
||||
return self.func(action)
|
||||
|
||||
|
||||
class ClipActionV0(LambdaActionV0):
|
||||
"""Clip the continuous action within the valid :class:`Box` observation space bound.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> import numpy as np
|
||||
>>> env = gym.make('BipedalWalker-v3', disable_env_checker=True)
|
||||
>>> env = ClipActionV0(env)
|
||||
>>> env.action_space
|
||||
Box(-1.0, 1.0, (4,), float32)
|
||||
>>> env.step(np.array([5.0, 2.0, -10.0, 0.0]))
|
||||
# Executes the action np.array([1.0, 1.0, -1.0, 0]) in the base environment
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env):
|
||||
"""A wrapper for clipping continuous actions within the valid bound.
|
||||
|
||||
Args:
|
||||
env: The environment to apply the wrapper
|
||||
"""
|
||||
assert isinstance(env.action_space, spaces.Box)
|
||||
super().__init__(
|
||||
env,
|
||||
lambda action: jp.clip(action, env.action_space.low, env.action_space.high),
|
||||
)
|
||||
|
||||
self.action_space = spaces.Box(-np.inf, np.inf, env.action_space.shape)
|
||||
|
||||
|
||||
class RescaleActionV0(LambdaActionV0):
|
||||
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
|
||||
|
||||
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
|
||||
or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space.
|
||||
|
||||
Example:
|
||||
>>> import gymnasium as gym
|
||||
>>> import numpy as np
|
||||
>>> env = gym.make('BipedalWalker-v3', disable_env_checker=True)
|
||||
>>> _ = env.reset(seed=42)
|
||||
>>> obs, _, _, _, _ = env.step(np.array([1,1,1,1]))
|
||||
>>> _ = env.reset(seed=42)
|
||||
>>> min_action = -0.5
|
||||
>>> max_action = np.array([0.0, 0.5, 1.0, 0.75])
|
||||
>>> wrapped_env = RescaleActionV0(env, min_action=min_action, max_action=max_action)
|
||||
>>> wrapped_env_obs, _, _, _, _ = wrapped_env.step(max_action)
|
||||
>>> np.alltrue(obs == wrapped_env_obs)
|
||||
True
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
min_action: Union[float, int, np.ndarray],
|
||||
max_action: Union[float, int, np.ndarray],
|
||||
):
|
||||
"""Initializes the :class:`RescaleAction` wrapper.
|
||||
|
||||
Args:
|
||||
env (Env): The environment to apply the wrapper
|
||||
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
|
||||
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
|
||||
"""
|
||||
assert isinstance(
|
||||
env.action_space, spaces.Box
|
||||
), f"expected Box action space, got {type(env.action_space)}"
|
||||
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
|
||||
|
||||
low = env.action_space.low
|
||||
high = env.action_space.high
|
||||
|
||||
self.min_action = np.full(
|
||||
env.action_space.shape, min_action, dtype=env.action_space.dtype
|
||||
)
|
||||
self.max_action = np.full(
|
||||
env.action_space.shape, max_action, dtype=env.action_space.dtype
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
env,
|
||||
lambda action: jp.clip(
|
||||
low
|
||||
+ (high - low)
|
||||
* ((action - self.min_action) / (self.max_action - self.min_action)),
|
||||
low,
|
||||
high,
|
||||
),
|
||||
)
|
||||
|
@@ -4,7 +4,7 @@ from typing import Union
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
from gymnasium.spaces import Box
|
||||
|
||||
|
||||
class RescaleAction(gym.ActionWrapper):
|
||||
@@ -41,7 +41,7 @@ class RescaleAction(gym.ActionWrapper):
|
||||
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
|
||||
"""
|
||||
assert isinstance(
|
||||
env.action_space, spaces.Box
|
||||
env.action_space, Box
|
||||
), f"expected Box action space, got {type(env.action_space)}"
|
||||
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
|
||||
|
||||
@@ -52,12 +52,6 @@ class RescaleAction(gym.ActionWrapper):
|
||||
self.max_action = (
|
||||
np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + max_action
|
||||
)
|
||||
self.action_space = spaces.Box(
|
||||
low=min_action,
|
||||
high=max_action,
|
||||
shape=env.action_space.shape,
|
||||
dtype=env.action_space.dtype,
|
||||
)
|
||||
|
||||
def action(self, action):
|
||||
"""Rescales the action affinely from [:attr:`min_action`, :attr:`max_action`] to the action space of the base environment, :attr:`env`.
|
||||
|
@@ -12,4 +12,4 @@ pygame==2.1.3.dev8
|
||||
ale-py~=0.8.0
|
||||
mujoco==2.2
|
||||
mujoco_py<2.2,>=2.1
|
||||
imageio>=2.14.1
|
||||
imageio>=2.14.1
|
1
setup.py
1
setup.py
@@ -86,6 +86,7 @@ setup(
|
||||
include_package_data=True,
|
||||
install_requires=[
|
||||
"numpy >= 1.21.0",
|
||||
"jax-jumpy >= 0.2.0",
|
||||
"cloudpickle >= 1.2.0",
|
||||
"importlib_metadata >= 4.8.0; python_version < '3.10'",
|
||||
"gymnasium_notices >= 0.0.1",
|
||||
|
47
tests/experimental/wrappers/test_clip_action.py
Normal file
47
tests/experimental/wrappers/test_clip_action.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Test suite for LambdaActionV0."""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.experimental.wrappers import ClipActionV0
|
||||
|
||||
SEED = 42
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("env", "action_unclipped_env", "action_clipped_env"),
|
||||
(
|
||||
[
|
||||
# MountainCar action space: Box(-1.0, 1.0, (1,), float32)
|
||||
gym.make("MountainCarContinuous-v0"),
|
||||
np.array([1]),
|
||||
np.array([1.5]),
|
||||
],
|
||||
[
|
||||
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
|
||||
gym.make("BipedalWalker-v3"),
|
||||
np.array([1, 1, 1, 1]),
|
||||
np.array([10, 10, 10, 10]),
|
||||
],
|
||||
[
|
||||
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
|
||||
gym.make("BipedalWalker-v3"),
|
||||
np.array([0.5, 0.5, 1, 1]),
|
||||
np.array([0.5, 0.5, 10, 10]),
|
||||
],
|
||||
),
|
||||
)
|
||||
def test_clip_actions_v0(env, action_unclipped_env, action_clipped_env):
|
||||
"""Tests if actions out of bound are correctly clipped.
|
||||
|
||||
Tests whether out of bound actions for the wrapped
|
||||
environments are correctly clipped.
|
||||
"""
|
||||
env.reset(seed=SEED)
|
||||
obs, _, _, _, _ = env.step(action_unclipped_env)
|
||||
|
||||
env.reset(seed=SEED)
|
||||
wrapped_env = ClipActionV0(env)
|
||||
wrapped_obs, _, _, _, _ = wrapped_env.step(action_clipped_env)
|
||||
|
||||
assert np.alltrue(obs == wrapped_obs)
|
@@ -12,7 +12,7 @@ NUM_ENVS = 3
|
||||
BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64)
|
||||
|
||||
|
||||
def env_step_fn(self, action):
|
||||
def generic_step_fn(self, action):
|
||||
return 0, 0, False, False, {"action": action}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ def env_step_fn(self, action):
|
||||
("env", "func", "action", "expected"),
|
||||
[
|
||||
(
|
||||
GenericTestEnv(action_space=BOX_SPACE, step_fn=env_step_fn),
|
||||
GenericTestEnv(action_space=BOX_SPACE, step_fn=generic_step_fn),
|
||||
lambda action: action + 2,
|
||||
1,
|
||||
3,
|
||||
|
51
tests/experimental/wrappers/test_rescale_action.py
Normal file
51
tests/experimental/wrappers/test_rescale_action.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Test suite for RescaleActionV0."""
|
||||
import jax
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.experimental.wrappers import RescaleActionV0
|
||||
|
||||
SEED = 42
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("env", "low", "high", "action", "scaled_action"),
|
||||
[
|
||||
(
|
||||
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
|
||||
gym.make("BipedalWalker-v3"),
|
||||
-0.5,
|
||||
0.5,
|
||||
np.array([1, 1, 1, 1]),
|
||||
np.array([0.5, 0.5, 0.5, 0.5]),
|
||||
),
|
||||
(
|
||||
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
|
||||
gym.make("BipedalWalker-v3"),
|
||||
-0.5,
|
||||
0.5,
|
||||
jax.numpy.array([1, 1, 1, 1]),
|
||||
jax.numpy.array([0.5, 0.5, 0.5, 0.5]),
|
||||
),
|
||||
(
|
||||
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
|
||||
gym.make("BipedalWalker-v3"),
|
||||
np.array([-0.5, -0.5, -1, -1], dtype=np.float32),
|
||||
np.array([0.5, 0.5, 1, 1], dtype=np.float32),
|
||||
jax.numpy.array([1, 1, 1, 1]),
|
||||
jax.numpy.array([0.5, 0.5, 1, 1]),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rescale_actions_v0_box(env, low, high, action, scaled_action):
|
||||
"""Test action rescaling."""
|
||||
env.reset(seed=SEED)
|
||||
obs, _, _, _, _ = env.step(action)
|
||||
|
||||
env.reset(seed=SEED)
|
||||
wrapped_env = RescaleActionV0(env, low, high)
|
||||
|
||||
obs_scaled, _, _, _, _ = wrapped_env.step(scaled_action)
|
||||
|
||||
assert np.alltrue(obs == obs_scaled)
|
Reference in New Issue
Block a user