Files
Gymnasium/gymnasium/dev_wrappers/lambda_reward.py

87 lines
2.6 KiB
Python
Raw Normal View History

2022-11-20 00:57:10 +01:00
"""Lambda reward wrappers which apply a function to the reward."""
from typing import Any, Callable, Optional, Union
import numpy as np
import gymnasium as gym
from gymnasium.dev_wrappers import ArgType
from gymnasium.error import InvalidBound
class LambdaRewardV0(gym.RewardWrapper):
"""A reward wrapper that allows a custom function to modify the step reward.
Example:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import LambdaRewardV0
>>> env = gym.make("CartPole-v1")
>>> env = LambdaRewardV0(env, lambda r: 2 * r + 1)
>>> _ = env.reset()
>>> _, rew, _, _, _ = env.step(0)
>>> rew
3.0
"""
def __init__(
self,
env: gym.Env,
func: Callable[[ArgType], Any],
):
"""Initialize LambdaRewardV0 wrapper.
Args:
env (Env): The environment to apply the wrapper
func: (Callable): The function to apply to reward
"""
super().__init__(env)
self.func = func
def reward(self, reward: Union[float, int, np.ndarray]) -> Any:
"""Apply function to reward.
Args:
reward (Union[float, int, np.ndarray]): environment's reward
"""
return self.func(reward)
class ClipRewardsV0(LambdaRewardV0):
"""A wrapper that clips the rewards for an environment between an upper and lower bound.
Example with an upper and lower bound:
>>> import gymnasium as gym
>>> from gymnasium.wrappers import ClipRewardsV0
>>> env = gym.make("CartPole-v1")
>>> env = ClipRewardsV0(env, 0, 0.5)
>>> env.reset()
>>> _, rew, _, _, _ = env.step(1)
>>> rew
0.5
"""
def __init__(
self,
env: gym.Env,
min_reward: Optional[Union[float, np.ndarray]] = None,
max_reward: Optional[Union[float, np.ndarray]] = None,
):
"""Initialize ClipRewardsV0 wrapper.
Args:
env (Env): The environment to apply the wrapper
min_reward (Union[float, np.ndarray]): lower bound to apply
max_reward (Union[float, np.ndarray]): higher bound to apply
"""
if min_reward is None and max_reward is None:
raise InvalidBound("Both `min_reward` and `max_reward` cannot be None")
elif max_reward is not None and min_reward is not None:
if np.any(max_reward - min_reward < 0):
raise InvalidBound(
f"Min reward ({min_reward}) must be smaller than max reward ({max_reward})"
)
super().__init__(env, lambda x: np.clip(x, a_min=min_reward, a_max=max_reward))