From ed8b13e11afe82b1352fb84a9d8b29331ec4f654 Mon Sep 17 00:00:00 2001
From: Xingdong Zuo
Date: Fri, 11 Oct 2019 23:58:04 +0200
Subject: [PATCH] [Wrappers]: add TransformObservation (#1670)

* Create transform_observation.py
* Create test_transform_observation.py
* Update __init__.py
---
 gym/wrappers/__init__.py                   |  1 +
 gym/wrappers/test_transform_observation.py | 27 ++++++++++++++++++++++
 gym/wrappers/transform_observation.py      | 26 +++++++++++++++++++++
 3 files changed, 54 insertions(+)
 create mode 100644 gym/wrappers/test_transform_observation.py
 create mode 100644 gym/wrappers/transform_observation.py

diff --git a/gym/wrappers/__init__.py b/gym/wrappers/__init__.py
index 5bc7c6886..8e9e80285 100644
--- a/gym/wrappers/__init__.py
+++ b/gym/wrappers/__init__.py
@@ -8,6 +8,7 @@ from gym.wrappers.flatten_observation import FlattenObservation
 from gym.wrappers.gray_scale_observation import GrayScaleObservation
 from gym.wrappers.frame_stack import LazyFrames
 from gym.wrappers.frame_stack import FrameStack
+from gym.wrappers.transform_observation import TransformObservation
 from gym.wrappers.transform_reward import TransformReward
 from gym.wrappers.resize_observation import ResizeObservation
 from gym.wrappers.clip_action import ClipAction
diff --git a/gym/wrappers/test_transform_observation.py b/gym/wrappers/test_transform_observation.py
new file mode 100644
index 000000000..07eecd7eb
--- /dev/null
+++ b/gym/wrappers/test_transform_observation.py
@@ -0,0 +1,27 @@
+import pytest
+
+import numpy as np
+
+import gym
+from gym.wrappers import TransformObservation
+
+
+@pytest.mark.parametrize('env_id', ['CartPole-v1', 'Pendulum-v0'])
+def test_transform_observation(env_id):
+    affine_transform = lambda x: 3*x + 2
+    env = gym.make(env_id)
+    wrapped_env = TransformObservation(gym.make(env_id), lambda obs: affine_transform(obs))
+
+    env.seed(0)
+    wrapped_env.seed(0)
+
+    obs = env.reset()
+    wrapped_obs = wrapped_env.reset()
+    assert np.allclose(wrapped_obs, affine_transform(obs))
+
+    action = env.action_space.sample()
+    obs, reward, done, _ = env.step(action)
+    wrapped_obs, wrapped_reward, wrapped_done, _ = wrapped_env.step(action)
+    assert np.allclose(wrapped_obs, affine_transform(obs))
+    assert np.allclose(wrapped_reward, reward)
+    assert wrapped_done == done
diff --git a/gym/wrappers/transform_observation.py b/gym/wrappers/transform_observation.py
new file mode 100644
index 000000000..6b15a59bd
--- /dev/null
+++ b/gym/wrappers/transform_observation.py
@@ -0,0 +1,26 @@
+from gym import ObservationWrapper
+
+
+class TransformObservation(ObservationWrapper):
+    r"""Transform the observation via an arbitrary function.
+
+    Example::
+
+        >>> import gym
+        >>> env = gym.make('CartPole-v1')
+        >>> env = TransformObservation(env, lambda obs: obs + 0.1*np.random.randn(*obs.shape))
+        >>> env.reset()
+        array([-0.08319338, 0.04635121, -0.07394746, 0.20877492])
+
+    Args:
+        env (Env): environment
+        f (callable): a function that transforms the observation
+
+    """
+    def __init__(self, env, f):
+        super(TransformObservation, self).__init__(env)
+        assert callable(f)
+        self.f = f
+
+    def observation(self, observation):
+        return self.f(observation)