mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-31 18:12:53 +00:00
[Wrappers]: RescaleAction (#1491)
* Create normalize_action.py * Update __init__.py * Create test_normalize_action.py * Update normalize_action.py * Update normalize_action.py * Rename normalize_action.py to rescale_action.py * Update __init__.py * Update rescale_action.py * Update and rename test_normalize_action.py to test_rescale_action.py * Update test_rescale_action.py * Update rescale_action.py * Update rescale_action.py * Update gym/wrappers/rescale_action.py Thanks a lot @hartikainen ! Co-Authored-By: Kristian Hartikainen <kristian.hartikainen@gmail.com> * Update gym/wrappers/rescale_action.py That's a very clean way ! Thanks ! Co-Authored-By: Kristian Hartikainen <kristian.hartikainen@gmail.com> * Update gym/wrappers/rescale_action.py Co-Authored-By: Kristian Hartikainen <kristian.hartikainen@gmail.com> * Update rescale_action.py * Update gym/wrappers/rescale_action.py Co-Authored-By: Kristian Hartikainen <kristian.hartikainen@gmail.com> * Update rescale_action.py * Update rescale_action.py
This commit is contained in:
@@ -3,6 +3,7 @@ from gym.wrappers.monitor import Monitor
|
||||
from gym.wrappers.time_limit import TimeLimit
|
||||
from gym.wrappers.filter_observation import FilterObservation
|
||||
from gym.wrappers.atari_preprocessing import AtariPreprocessing
|
||||
from gym.wrappers.rescale_action import RescaleAction
|
||||
from gym.wrappers.flatten_observation import FlattenObservation
|
||||
from gym.wrappers.gray_scale_observation import GrayScaleObservation
|
||||
from gym.wrappers.frame_stack import LazyFrames
|
||||
|
32
gym/wrappers/rescale_action.py
Normal file
32
gym/wrappers/rescale_action.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
|
||||
|
||||
class RescaleAction(gym.ActionWrapper):
    r"""Rescales the continuous action space of the environment to a range [a,b].

    The wrapper advertises ``Box(a, b)`` as its own action space and maps each
    incoming action affinely from ``[a, b]`` onto the wrapped environment's
    native ``[low, high]`` bounds before it is forwarded.

    Example::

        >>> RescaleAction(env, a, b).action_space == Box(a,b)
        True

    Args:
        env: Environment to wrap. Its ``action_space`` must be a ``spaces.Box``.
        a: Lower bound of the rescaled action range.
        b: Upper bound of the rescaled action range. Must be strictly greater
            than ``a``.

    Raises:
        AssertionError: If the environment's action space is not a ``Box``,
            or if ``a`` is not strictly less than ``b``.
    """

    def __init__(self, env, a, b):
        assert isinstance(env.action_space, spaces.Box), (
            "expected Box action space, got {}".format(type(env.action_space)))
        # Strict inequality on purpose: `action` divides by (b - a), so a
        # degenerate range (a == b) would turn the only admissible action
        # into 0/0 = NaN, which np.clip would pass straight to the env.
        assert np.less(a, b).all(), (a, b)
        super(RescaleAction, self).__init__(env)
        # Keep the bounds as arrays shaped like the env's action so the
        # arithmetic in `action` is element-wise.
        self.a = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + a
        self.b = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + b
        self.action_space = spaces.Box(low=a, high=b, shape=env.action_space.shape, dtype=env.action_space.dtype)

    def action(self, action):
        """Map ``action`` from ``[a, b]`` to the wrapped env's ``[low, high]``.

        Args:
            action: Action expressed in the rescaled range ``[a, b]``.

        Returns:
            The corresponding action in the wrapped environment's native
            range, clipped to ``[low, high]``.

        Raises:
            AssertionError: If any component of ``action`` lies outside
                ``[a, b]``.
        """
        assert np.all(np.greater_equal(action, self.a)), (action, self.a)
        assert np.all(np.less_equal(action, self.b)), (action, self.b)
        low = self.env.action_space.low
        high = self.env.action_space.high
        # Normalise to a fraction of [a, b], then stretch onto [low, high].
        action = low + (high - low)*((action - self.a)/(self.b - self.a))
        # Clip so floating-point rounding cannot push the result out of bounds.
        action = np.clip(action, low, high)
        return action
|
32
gym/wrappers/test_rescale_action.py
Normal file
32
gym/wrappers/test_rescale_action.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import pytest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
from gym.wrappers import RescaleAction
|
||||
|
||||
|
||||
def test_rescale_action():
    """Check that RescaleAction rejects bad input and rescales correctly."""
    # Wrapping CartPole must fail: with bounds (-1, 1) the only constructor
    # assertion that can trigger is the Box action-space check.
    rejected_env = gym.make('CartPole-v1')
    with pytest.raises(AssertionError):
        rejected_env = RescaleAction(rejected_env, -1, 1)
    del rejected_env

    # Two identically seeded Pendulum instances: one raw, one rescaled to [-1, 1].
    plain_env = gym.make('Pendulum-v0')
    rescaled_env = RescaleAction(gym.make('Pendulum-v0'), -1, 1)

    seed = 0
    plain_env.seed(seed)
    rescaled_env.seed(seed)

    plain_obs = plain_env.reset()
    rescaled_obs = rescaled_env.reset()
    assert np.allclose(plain_obs, rescaled_obs)

    plain_obs, plain_reward, _, _ = plain_env.step([1.5])

    # 1.5 lies outside the wrapper's [-1, 1] range, so it must be refused.
    with pytest.raises(AssertionError):
        rescaled_env.step([1.5])

    # 0.75 in [-1, 1] is expected to correspond to the raw action 1.5
    # (this holds if Pendulum's native bounds are [-2, 2] -- confirm).
    rescaled_obs, rescaled_reward, _, _ = rescaled_env.step([0.75])

    assert np.allclose(plain_obs, rescaled_obs)
    assert np.allclose(plain_reward, rescaled_reward)
|
Reference in New Issue
Block a user