Files
Gymnasium/tests/wrappers/test_normalize_reward.py

84 lines
2.6 KiB
Python
Raw Normal View History

"""Test suite for NormalizeReward wrapper."""
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType
from gymnasium.wrappers import NormalizeReward
from tests.testing_env import GenericTestEnv
def constant_reward_step_func(self, action: ActType):
    """Step function that ignores `action` and always produces a reward of 1.0.

    Returns a 5-tuple: an observation sampled from `self.observation_space`,
    the constant reward, two `False` flags and an empty info dict.
    """
    observation = self.observation_space.sample()
    reward = 1.0
    terminated, truncated = False, False
    return observation, reward, terminated, truncated, {}
def test_running_mean_normalize_reward_wrapper():
    """Tests that the `update_running_mean` property freezes/continues the running statistics.

    While the property is True, stepping the wrapped environment changes the
    mean and variance stored in `return_rms`. Once it is set to False, a
    further step must leave both values untouched.
    """
    env = GenericTestEnv(step_func=constant_reward_step_func)
    wrapped_env = NormalizeReward(env)

    # Default value is True
    assert wrapped_env.update_running_mean

    wrapped_env.reset()
    rms_var_init = wrapped_env.return_rms.var
    rms_mean_init = wrapped_env.return_rms.mean

    # Statistics are updated when env.step()
    wrapped_env.step(None)
    rms_var_updated = wrapped_env.return_rms.var
    rms_mean_updated = wrapped_env.return_rms.mean
    assert rms_var_init != rms_var_updated
    assert rms_mean_init != rms_mean_updated

    # Assure property is set
    wrapped_env.update_running_mean = False
    assert not wrapped_env.update_running_mean

    # Statistics are frozen
    wrapped_env.step(None)
    assert rms_var_updated == wrapped_env.return_rms.var
    assert rms_mean_updated == wrapped_env.return_rms.mean

    # Release the environment, matching the cleanup done by the sibling test.
    wrapped_env.close()
def test_normalize_reward_wrapper():
    """Smoke test: stepping a NormalizeReward-wrapped env runs and yields a scalar reward."""
    # TODO: Functional correctness should be tested
    env = GenericTestEnv(step_func=constant_reward_step_func)
    wrapped_env = NormalizeReward(env)
    wrapped_env.reset()

    # Only the reward is inspected; the remaining step outputs are discarded.
    _obs, reward, _terminated, _truncated, _info = wrapped_env.step(None)

    # A 0-dimensional value means the reward is a scalar, not an array.
    reward_ndim = np.ndim(reward)
    assert reward_ndim == 0

    env.close()
def reward_reset_func(self: gym.Env, seed=None, options=None):
    """Reset function that seeds the reward queue and emits its first entry.

    The queue starts as [0, 1, 2, 3, 4]; the leading 0 is consumed here and
    returned both as a one-element observation array and under the "reward"
    key of the info dict. `seed` and `options` are accepted but unused.
    """
    first_reward, *remaining = [0, 1, 2, 3, 4]
    self.rewards = remaining
    observation = np.array([first_reward])
    info = {"reward": first_reward}
    return observation, info
def reward_step_func(self: gym.Env, action):
    """Step function that pops the next queued reward and reports it.

    The popped value is returned as a one-element observation array, as the
    reward itself and under the "reward" key of the info dict. The third
    element is True once the queue in `self.rewards` is empty; the fourth is
    always False. `action` is ignored.
    """
    current = self.rewards.pop(0)
    queue_exhausted = not self.rewards
    return (
        np.array([current]),
        current,
        queue_exhausted,
        False,
        {"reward": current},
    )
def test_normalize_return():
    """Checks `return_rms.mean` against the mean of the discounted returns seen so far.

    The underlying env hands out rewards 1 and 2 on the first two steps
    (the 0 is consumed by reset), so the expected returns are 1 and then
    2 + 1 * gamma.
    """
    base_env = GenericTestEnv(reset_func=reward_reset_func, step_func=reward_step_func)
    normalized_env = NormalizeReward(base_env)
    normalized_env.reset()

    # After one step only the first return has been recorded.
    normalized_env.step(normalized_env.action_space.sample())
    first_return = 1
    np.testing.assert_almost_equal(
        normalized_env.return_rms.mean,
        np.mean([first_return]),
        decimal=4,
    )

    # After two steps: [second return, first return].
    normalized_env.step(normalized_env.action_space.sample())
    second_return = 2 + 1 * normalized_env.gamma
    np.testing.assert_almost_equal(
        normalized_env.return_rms.mean,
        np.mean([second_return, 1]),
        decimal=4,
    )