Files
Gymnasium/gymnasium/experimental/wrappers/lambda_observations.py

31 lines
847 B
Python
Raw Normal View History

2022-11-20 00:57:10 +01:00
"""Lambda observation wrappers which apply a function to the observation."""
from typing import Any, Callable, Optional

import gymnasium as gym
from gymnasium.core import ObsType
from gymnasium.experimental.wrappers import ArgType
2022-11-20 00:57:10 +01:00
class LambdaObservationV0(gym.ObservationWrapper):
    """Lambda observation wrapper where a function is provided that is applied to the observation.

    ``gym.ObservationWrapper`` routes observations produced by the wrapped environment
    through :meth:`observation`, which here simply returns ``func(observation)``.
    """

    def __init__(
        self,
        env: gym.Env,
        func: Callable[[ArgType], Any],
        observation_space: Optional[gym.Space] = None,
    ):
        """Constructor for the lambda observation wrapper.

        Args:
            env: The environment to wrap
            func: A function that takes an observation from ``env`` and returns the
                transformed observation that this wrapper emits instead
            observation_space: The observation space describing the transformed
                observations. If ``None`` (the default), the wrapper keeps the
                observation space inherited from ``env``, which is only accurate
                when ``func`` preserves the observation's shape, dtype and bounds.
        """
        super().__init__(env)
        # ``func`` may change the structure of the observation; in that case the
        # space inherited from ``env`` no longer describes what this wrapper returns,
        # so allow the caller to supply the correct one.
        if observation_space is not None:
            self.observation_space = observation_space
        self.func = func

    def observation(self, observation: ObsType) -> Any:
        """Apply ``func`` to the observation and return the result."""
        return self.func(observation)