[Wrappers]: RescaleAction (#1491)

* Create normalize_action.py

* Update __init__.py

* Create test_normalize_action.py

* Update normalize_action.py

* Update normalize_action.py

* Rename normalize_action.py to rescale_action.py

* Update __init__.py

* Update rescale_action.py

* Update and rename test_normalize_action.py to test_rescale_action.py

* Update test_rescale_action.py

* Update rescale_action.py

* Update rescale_action.py

* Update gym/wrappers/rescale_action.py

Thanks a lot, @hartikainen!

Co-Authored-By: Kristian Hartikainen <kristian.hartikainen@gmail.com>

* Update gym/wrappers/rescale_action.py

That's a very clean way! Thanks!

Co-Authored-By: Kristian Hartikainen <kristian.hartikainen@gmail.com>

* Update gym/wrappers/rescale_action.py

Co-Authored-By: Kristian Hartikainen <kristian.hartikainen@gmail.com>

* Update rescale_action.py

* Update gym/wrappers/rescale_action.py

Co-Authored-By: Kristian Hartikainen <kristian.hartikainen@gmail.com>

* Update rescale_action.py

* Update rescale_action.py
This commit is contained in:
Xingdong Zuo
2019-10-25 23:20:53 +02:00
committed by pzhokhov
parent b6b060036b
commit 1a5c786ef9
3 changed files with 65 additions and 0 deletions

View File

@@ -3,6 +3,7 @@ from gym.wrappers.monitor import Monitor
from gym.wrappers.time_limit import TimeLimit
from gym.wrappers.filter_observation import FilterObservation
from gym.wrappers.atari_preprocessing import AtariPreprocessing
from gym.wrappers.rescale_action import RescaleAction
from gym.wrappers.flatten_observation import FlattenObservation
from gym.wrappers.gray_scale_observation import GrayScaleObservation
from gym.wrappers.frame_stack import LazyFrames

View File

@@ -0,0 +1,32 @@
import numpy as np
import gym
from gym import spaces
class RescaleAction(gym.ActionWrapper):
    r"""Rescales the continuous action space of the environment to a range [a,b].

    Actions passed to the wrapper must lie in ``[a, b]``; they are mapped
    affinely onto ``[low, high]`` of the wrapped environment's ``Box`` action
    space before being forwarded.

    Args:
        env: Environment to wrap. Its ``action_space`` must be a ``spaces.Box``.
        a: Lower bound of the rescaled action space. A scalar, or an array
            that broadcasts against ``env.action_space.shape``.
        b: Upper bound of the rescaled action space, same conventions as
            ``a``. Must be strictly greater than ``a`` element-wise.

    Raises:
        AssertionError: If the action space is not a ``Box``, or if ``a`` is
            not strictly smaller than ``b`` everywhere.

    Example::

        >>> RescaleAction(env, a, b).action_space == Box(a,b)
        True

    """

    def __init__(self, env, a, b):
        # Validation is kept as asserts so invalid input continues to raise
        # AssertionError, as it did before.
        assert isinstance(env.action_space, spaces.Box), (
            "expected Box action space, got {}".format(type(env.action_space)))
        # Strict inequality: with a == b, action() would divide by zero and
        # forward NaN actions to the wrapped environment.
        assert np.less(a, b).all(), (a, b)
        super(RescaleAction, self).__init__(env)

        shape = env.action_space.shape
        dtype = env.action_space.dtype
        # Broadcast the (scalar or array) bounds to the full action shape.
        self.a = np.zeros(shape, dtype=dtype) + a
        self.b = np.zeros(shape, dtype=dtype) + b
        # Build the Box from the broadcast arrays (no ``shape`` argument) so
        # that array-valued bounds work as well as scalar ones.
        self.action_space = spaces.Box(low=self.a, high=self.b, dtype=dtype)

    def action(self, action):
        """Map ``action`` from ``[a, b]`` to the wrapped env's ``[low, high]``.

        Args:
            action: Action expressed in the rescaled space; every component
                must satisfy ``a <= action <= b``.

        Returns:
            The corresponding action in the wrapped environment's range.

        Raises:
            AssertionError: If any component of ``action`` is outside ``[a, b]``.
        """
        assert np.all(np.greater_equal(action, self.a)), (action, self.a)
        assert np.all(np.less_equal(action, self.b)), (action, self.b)
        low = self.env.action_space.low
        high = self.env.action_space.high
        # Affine map: a -> low, b -> high.
        action = low + (high - low)*((action - self.a)/(self.b - self.a))
        # The input is already range-checked, so this clip should only ever
        # trim floating-point round-off at the boundaries.
        action = np.clip(action, low, high)
        return action

View File

@@ -0,0 +1,32 @@
import pytest
import numpy as np
import gym
from gym.wrappers import RescaleAction
def test_rescale_action():
    """Check RescaleAction's validation and its action mapping.

    Wrapping 'CartPole-v1' is expected to fail with AssertionError. For
    'Pendulum-v0' wrapped to [-1, 1], stepping the wrapper with [0.75] must
    reproduce stepping the plain environment with [1.5] (presumably because
    the underlying action range is [-2, 2] -- confirm against the env), and
    an out-of-range action of [1.5] must be rejected by the wrapper.
    """
    unsupported_env = gym.make('CartPole-v1')
    with pytest.raises(AssertionError):
        RescaleAction(unsupported_env, -1, 1)
    del unsupported_env

    raw_env = gym.make('Pendulum-v0')
    scaled_env = RescaleAction(gym.make('Pendulum-v0'), -1, 1)

    # Seed both environments identically so their trajectories are comparable.
    shared_seed = 0
    for candidate in (raw_env, scaled_env):
        candidate.seed(shared_seed)

    raw_obs = raw_env.reset()
    scaled_obs = scaled_env.reset()
    assert np.allclose(raw_obs, scaled_obs)

    raw_obs, raw_reward, _, _ = raw_env.step([1.5])

    # 1.5 lies outside the wrapper's [-1, 1] range and must be refused.
    with pytest.raises(AssertionError):
        scaled_env.step([1.5])

    scaled_obs, scaled_reward, _, _ = scaled_env.step([0.75])
    assert np.allclose(raw_obs, scaled_obs)
    assert np.allclose(raw_reward, scaled_reward)