diff --git a/gymnasium/wrappers/stateful_observation.py b/gymnasium/wrappers/stateful_observation.py index 9414ae17f..d593dc720 100644 --- a/gymnasium/wrappers/stateful_observation.py +++ b/gymnasium/wrappers/stateful_observation.py @@ -487,6 +487,7 @@ class NormalizeObservation( Change logs: * v0.21.0 - Initially add * v1.0.0 - Add `update_running_mean` attribute to allow disabling of updating the running mean / standard, particularly useful for evaluation time. + Casts all observations to `np.float32` and sets the observation space with low/high of `-np.inf` and `np.inf` and dtype as `np.float32` """ def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8): @@ -499,6 +500,14 @@ class NormalizeObservation( gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon) gym.ObservationWrapper.__init__(self, env) + assert env.observation_space.shape is not None + self.observation_space = gym.spaces.Box( + low=-np.inf, + high=np.inf, + shape=env.observation_space.shape, + dtype=np.float32, + ) + self.obs_rms = RunningMeanStd( shape=self.observation_space.shape, dtype=self.observation_space.dtype ) @@ -519,8 +528,8 @@ class NormalizeObservation( """Normalises the observation using the running mean and variance of the observations.""" if self._update_running_mean: self.obs_rms.update(np.array([observation])) - return (observation - self.obs_rms.mean) / np.sqrt( - self.obs_rms.var + self.epsilon + return np.float32( + (observation - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon) ) diff --git a/tests/wrappers/test_normalize_observation.py b/tests/wrappers/test_normalize_observation.py index 0fff7a995..9639d4100 100644 --- a/tests/wrappers/test_normalize_observation.py +++ b/tests/wrappers/test_normalize_observation.py @@ -1,6 +1,7 @@ """Test suite for NormalizeObservation wrapper.""" import numpy as np +import gymnasium as gym from gymnasium import spaces, wrappers from gymnasium.wrappers import NormalizeObservation from tests.testing_env import GenericTestEnv @@ -62,3 +63,14 @@ def test_update_running_mean_property(): wrapped_env.step(None) assert rms_var_updated == wrapped_env.obs_rms.var assert rms_mean_updated == wrapped_env.obs_rms.mean + + +def test_normalize_obs_with_vector(): + def thunk(): + env = gym.make("CarRacing-v2") + env = gym.wrappers.GrayscaleObservation(env) + env = gym.wrappers.NormalizeObservation(env) + return env + + envs = gym.vector.SyncVectorEnv([thunk for _ in range(4)]) + obs, _ = envs.reset()