Files
Gymnasium/gymnasium/experimental/wrappers/lambda_action.py
2022-12-05 19:14:56 +00:00

155 lines
5.4 KiB
Python

"""A collection of wrappers that all use the LambdaAction class.
* ``LambdaAction`` - Transforms the actions based on a function
* ``ClipAction`` - Clips the action within a bounds
* ``RescaleAction`` - Rescales the action within a minimum and maximum actions
"""
from __future__ import annotations
from typing import Callable
import jumpy as jp
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, WrapperActType
from gymnasium.spaces import Box, Space
class LambdaActionV0(gym.ActionWrapper):
"""A wrapper that provides a function to modify the action passed to :meth:`step`."""
def __init__(
self,
env: gym.Env,
func: Callable[[WrapperActType], ActType],
action_space: Space | None,
):
"""Initialize LambdaAction.
Args:
env: The gymnasium environment
func: Function to apply to ``step`` ``action``
action_space: The updated action space of the wrapper given the function.
"""
super().__init__(env)
if action_space is not None:
self.action_space = action_space
self.func = func
def action(self, action: WrapperActType) -> ActType:
"""Apply function to action."""
return self.func(action)
class ClipActionV0(LambdaActionV0):
"""Clip the continuous action within the valid :class:`Box` observation space bound.
Example:
>>> import gymnasium as gym
>>> import numpy as np
>>> env = gym.make('BipedalWalker-v3', disable_env_checker=True)
>>> env = ClipActionV0(env)
>>> env.action_space
Box(-1.0, 1.0, (4,), float32)
>>> env.step(np.array([5.0, 2.0, -10.0, 0.0]))
# Executes the action np.array([1.0, 1.0, -1.0, 0]) in the base environment
"""
def __init__(self, env: gym.Env):
"""A wrapper for clipping continuous actions within the valid bound.
Args:
env: The environment to apply the wrapper
"""
assert isinstance(env.action_space, Box)
super().__init__(
env,
lambda action: jp.clip(action, env.action_space.low, env.action_space.high),
Box(
-np.inf,
np.inf,
shape=env.action_space.shape,
dtype=env.action_space.dtype,
),
)
class RescaleActionV0(LambdaActionV0):
"""Affinely rescales the continuous action space of the environment to the range [min_action, max_action].
The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space.
Example:
>>> import gymnasium as gym
>>> import numpy as np
>>> env = gym.make('BipedalWalker-v3', disable_env_checker=True)
>>> _ = env.reset(seed=42)
>>> obs, _, _, _, _ = env.step(np.array([1,1,1,1]))
>>> _ = env.reset(seed=42)
>>> min_action = -0.5
>>> max_action = np.array([0.0, 0.5, 1.0, 0.75])
>>> wrapped_env = RescaleActionV0(env, min_action=min_action, max_action=max_action)
>>> wrapped_env_obs, _, _, _, _ = wrapped_env.step(max_action)
>>> np.alltrue(obs == wrapped_env_obs)
True
"""
def __init__(
self,
env: gym.Env,
min_action: float | int | np.ndarray,
max_action: float | int | np.ndarray,
):
"""Initializes the :class:`RescaleAction` wrapper.
Args:
env (Env): The environment to apply the wrapper
min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
"""
assert isinstance(env.action_space, Box)
assert not np.any(env.action_space.low == np.inf) and not np.any(
env.action_space.high == np.inf
)
if not isinstance(min_action, np.ndarray):
assert np.issubdtype(type(min_action), np.integer) or np.issubdtype(
type(max_action), np.floating
)
min_action = np.full(env.action_space.shape, min_action)
assert min_action.shape == env.action_space.shape
assert not np.any(min_action == np.inf)
if not isinstance(max_action, np.ndarray):
assert np.issubdtype(type(max_action), np.integer) or np.issubdtype(
type(max_action), np.floating
)
max_action = np.full(env.action_space.shape, max_action)
assert max_action.shape == env.action_space.shape
assert not np.any(max_action == np.inf)
assert isinstance(env.action_space, Box)
assert np.all(np.less_equal(min_action, max_action))
# Imagine the x-axis between the old Box and the y-axis being the new Box
gradient = (env.action_space.high - env.action_space.low) / (
max_action - min_action
)
intercept = gradient * -min_action + env.action_space.low
super().__init__(
env,
lambda action: gradient * action + intercept,
Box(
low=min_action,
high=max_action,
shape=env.action_space.shape,
dtype=env.action_space.dtype,
),
)