Versioned Action wrappers which supports jumpy (#150)

This commit is contained in:
Gianluca De Cola
2022-12-01 21:36:11 +01:00
committed by GitHub
parent 18f123da9a
commit 9b73630d9b
8 changed files with 205 additions and 13 deletions

View File

@@ -5,7 +5,11 @@ from typing import TypeVar
ArgType = TypeVar("ArgType")
from gymnasium.experimental.wrappers.lambda_action import LambdaActionV0
from gymnasium.experimental.wrappers.lambda_action import (
LambdaActionV0,
ClipActionV0,
RescaleActionV0,
)
from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0
from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0
@@ -13,6 +17,8 @@ __all__ = [
"ArgType",
# Lambda Action
"LambdaActionV0",
"ClipActionV0",
"RescaleActionV0",
# Lambda Observation
"LambdaObservationV0",
# Lambda Reward

View File

@@ -1,8 +1,11 @@
"""Lambda action wrapper which apply a function to the provided action."""
from typing import Any, Callable, Union
from typing import Any, Callable
import jumpy as jp
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from gymnasium.core import ActType
from gymnasium.experimental.wrappers import ArgType
@@ -28,3 +31,93 @@ class LambdaActionV0(gym.ActionWrapper):
def action(self, action: ActType) -> Any:
"""Apply function to action."""
return self.func(action)
class ClipActionV0(LambdaActionV0):
"""Clip the continuous action within the valid :class:`Box` observation space bound.
Example:
>>> import gymnasium as gym
>>> import numpy as np
>>> env = gym.make('BipedalWalker-v3', disable_env_checker=True)
>>> env = ClipActionV0(env)
>>> env.action_space
Box(-1.0, 1.0, (4,), float32)
>>> env.step(np.array([5.0, 2.0, -10.0, 0.0]))
# Executes the action np.array([1.0, 1.0, -1.0, 0]) in the base environment
"""
def __init__(self, env: gym.Env):
"""A wrapper for clipping continuous actions within the valid bound.
Args:
env: The environment to apply the wrapper
"""
assert isinstance(env.action_space, spaces.Box)
super().__init__(
env,
lambda action: jp.clip(action, env.action_space.low, env.action_space.high),
)
self.action_space = spaces.Box(-np.inf, np.inf, env.action_space.shape)
class RescaleActionV0(LambdaActionV0):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space.
Example:
>>> import gymnasium as gym
>>> import numpy as np
>>> env = gym.make('BipedalWalker-v3', disable_env_checker=True)
>>> _ = env.reset(seed=42)
>>> obs, _, _, _, _ = env.step(np.array([1,1,1,1]))
>>> _ = env.reset(seed=42)
>>> min_action = -0.5
>>> max_action = np.array([0.0, 0.5, 1.0, 0.75])
>>> wrapped_env = RescaleActionV0(env, min_action=min_action, max_action=max_action)
>>> wrapped_env_obs, _, _, _, _ = wrapped_env.step(max_action)
>>> np.alltrue(obs == wrapped_env_obs)
True
"""
def __init__(
self,
env: gym.Env,
min_action: Union[float, int, np.ndarray],
max_action: Union[float, int, np.ndarray],
):
"""Initializes the :class:`RescaleAction` wrapper.
Args:
env (Env): The environment to apply the wrapper
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
"""
assert isinstance(
env.action_space, spaces.Box
), f"expected Box action space, got {type(env.action_space)}"
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
low = env.action_space.low
high = env.action_space.high
self.min_action = np.full(
env.action_space.shape, min_action, dtype=env.action_space.dtype
)
self.max_action = np.full(
env.action_space.shape, max_action, dtype=env.action_space.dtype
)
super().__init__(
env,
lambda action: jp.clip(
low
+ (high - low)
* ((action - self.min_action) / (self.max_action - self.min_action)),
low,
high,
),
)

View File

@@ -4,7 +4,7 @@ from typing import Union
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from gymnasium.spaces import Box
class RescaleAction(gym.ActionWrapper):
@@ -41,7 +41,7 @@ class RescaleAction(gym.ActionWrapper):
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
"""
assert isinstance(
env.action_space, spaces.Box
env.action_space, Box
), f"expected Box action space, got {type(env.action_space)}"
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
@@ -52,12 +52,6 @@ class RescaleAction(gym.ActionWrapper):
self.max_action = (
np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + max_action
)
self.action_space = spaces.Box(
low=min_action,
high=max_action,
shape=env.action_space.shape,
dtype=env.action_space.dtype,
)
def action(self, action):
"""Rescales the action affinely from [:attr:`min_action`, :attr:`max_action`] to the action space of the base environment, :attr:`env`.

View File

@@ -12,4 +12,4 @@ pygame==2.1.3.dev8
ale-py~=0.8.0
mujoco==2.2
mujoco_py<2.2,>=2.1
imageio>=2.14.1
imageio>=2.14.1

View File

@@ -86,6 +86,7 @@ setup(
include_package_data=True,
install_requires=[
"numpy >= 1.21.0",
"jax-jumpy >= 0.2.0",
"cloudpickle >= 1.2.0",
"importlib_metadata >= 4.8.0; python_version < '3.10'",
"gymnasium_notices >= 0.0.1",

View File

@@ -0,0 +1,47 @@
"""Test suite for LambdaActionV0."""
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.experimental.wrappers import ClipActionV0
SEED = 42
@pytest.mark.parametrize(
("env", "action_unclipped_env", "action_clipped_env"),
(
[
# MountainCar action space: Box(-1.0, 1.0, (1,), float32)
gym.make("MountainCarContinuous-v0"),
np.array([1]),
np.array([1.5]),
],
[
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
gym.make("BipedalWalker-v3"),
np.array([1, 1, 1, 1]),
np.array([10, 10, 10, 10]),
],
[
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
gym.make("BipedalWalker-v3"),
np.array([0.5, 0.5, 1, 1]),
np.array([0.5, 0.5, 10, 10]),
],
),
)
def test_clip_actions_v0(env, action_unclipped_env, action_clipped_env):
"""Tests if actions out of bound are correctly clipped.
Tests whether out of bound actions for the wrapped
environments are correctly clipped.
"""
env.reset(seed=SEED)
obs, _, _, _, _ = env.step(action_unclipped_env)
env.reset(seed=SEED)
wrapped_env = ClipActionV0(env)
wrapped_obs, _, _, _, _ = wrapped_env.step(action_clipped_env)
assert np.alltrue(obs == wrapped_obs)

View File

@@ -12,7 +12,7 @@ NUM_ENVS = 3
BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64)
def env_step_fn(self, action):
def generic_step_fn(self, action):
return 0, 0, False, False, {"action": action}
@@ -20,7 +20,7 @@ def env_step_fn(self, action):
("env", "func", "action", "expected"),
[
(
GenericTestEnv(action_space=BOX_SPACE, step_fn=env_step_fn),
GenericTestEnv(action_space=BOX_SPACE, step_fn=generic_step_fn),
lambda action: action + 2,
1,
3,

View File

@@ -0,0 +1,51 @@
"""Test suite for RescaleActionV0."""
import jax
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.experimental.wrappers import RescaleActionV0
SEED = 42
@pytest.mark.parametrize(
("env", "low", "high", "action", "scaled_action"),
[
(
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
gym.make("BipedalWalker-v3"),
-0.5,
0.5,
np.array([1, 1, 1, 1]),
np.array([0.5, 0.5, 0.5, 0.5]),
),
(
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
gym.make("BipedalWalker-v3"),
-0.5,
0.5,
jax.numpy.array([1, 1, 1, 1]),
jax.numpy.array([0.5, 0.5, 0.5, 0.5]),
),
(
# BipedalWalker action space: Box(-1.0, 1.0, (4,), float32)
gym.make("BipedalWalker-v3"),
np.array([-0.5, -0.5, -1, -1], dtype=np.float32),
np.array([0.5, 0.5, 1, 1], dtype=np.float32),
jax.numpy.array([1, 1, 1, 1]),
jax.numpy.array([0.5, 0.5, 1, 1]),
),
],
)
def test_rescale_actions_v0_box(env, low, high, action, scaled_action):
"""Test action rescaling."""
env.reset(seed=SEED)
obs, _, _, _, _ = env.step(action)
env.reset(seed=SEED)
wrapped_env = RescaleActionV0(env, low, high)
obs_scaled, _, _, _, _ = wrapped_env.step(scaled_action)
assert np.alltrue(obs == obs_scaled)