Files
Gymnasium/gym/wrappers/transform_observation.py

28 lines
741 B
Python
Raw Normal View History

from gym import ObservationWrapper
class TransformObservation(ObservationWrapper):
2021-07-29 02:26:34 +02:00
r"""Transform the observation via an arbitrary function.
Example::
>>> import gym
>>> env = gym.make('CartPole-v1')
>>> env = TransformObservation(env, lambda obs: obs + 0.1*np.random.randn(*obs.shape))
>>> env.reset()
array([-0.08319338, 0.04635121, -0.07394746, 0.20877492])
Args:
env (Env): environment
f (callable): a function that transforms the observation
"""
2021-07-29 02:26:34 +02:00
def __init__(self, env, f):
super(TransformObservation, self).__init__(env)
assert callable(f)
self.f = f
def observation(self, observation):
return self.f(observation)