Files
Gymnasium/tests/experimental/wrappers/test_lambda_observation.py

60 lines
1.6 KiB
Python
Raw Normal View History

2022-11-20 00:57:10 +01:00
"""Test suite for LambdaObservationV0."""
import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import LambdaObservationV0
2022-11-20 00:57:10 +01:00
from gymnasium.spaces import Box
2022-11-20 00:57:10 +01:00
# Shared fixtures for the tests in this module.
SEED = 42  # seed passed to every env.reset() call so runs are reproducible
NUM_ENVS = 3  # number of sub-environments in the vectorized test
DISCRETE_ACTION = 1  # action sent to each (sub-)environment on every step
# NOTE(review): not referenced by the tests visible in this chunk; kept for
# any other users of this module.
BOX_SPACE = Box(low=-5, high=5, shape=(1,), dtype=np.float64)
def test_lambda_observation_v0():
    """Tests lambda observation.

    Takes one step in a seeded CartPole environment to get a reference
    observation. It then resets with the same seed, wraps the environment
    so that ``observation_shift`` is added to every observation, and takes
    the same step again. The wrapped observation must equal the reference
    observation shifted by ``observation_shift``.
    """
    env = gym.make("CartPole-v1")
    try:
        env.reset(seed=SEED)
        obs, _, _, _, _ = env.step(DISCRETE_ACTION)

        observation_shift = 1

        # Re-seed so the wrapped step starts from the same state as above.
        env.reset(seed=SEED)
        # NOTE(review): the third positional argument is passed as None;
        # presumably this is the observation space (left unchanged) —
        # confirm against LambdaObservationV0's signature.
        wrapped_env = LambdaObservationV0(
            env, lambda observation: observation + observation_shift, None
        )
        wrapped_obs, _, _, _, _ = wrapped_env.step(DISCRETE_ACTION)

        # np.alltrue was removed in NumPy 2.0. assert_array_equal also
        # reports which elements differ on failure.
        np.testing.assert_array_equal(wrapped_obs, obs + observation_shift)
    finally:
        # The wrapper shares the underlying env, so closing env releases
        # the resources of both.
        env.close()
def test_lambda_observation_v0_within_vector():
    """Tests lambda observation in vectorized environments.

    Creates a synchronous vector of ``NUM_ENVS`` discrete-action CarRacing
    environments and takes one seeded step to get reference observations.
    It then resets with the same seed, wraps the vector environment so that
    ``observation_shift`` is added to the batched observation, and takes the
    same step again. The wrapped observations must equal the reference ones
    shifted by ``observation_shift``.
    """
    env = gym.vector.make(
        "CarRacing-v2", continuous=False, num_envs=NUM_ENVS, asynchronous=False
    )
    try:
        # Same action for every sub-environment. Built once and reused so
        # both steps receive an identical action batch.
        actions = np.array([DISCRETE_ACTION] * NUM_ENVS)

        env.reset(seed=SEED)
        obs, _, _, _, _ = env.step(actions)

        observation_shift = 1

        # Re-seed so the wrapped step starts from the same state as above.
        env.reset(seed=SEED)
        # NOTE(review): the third positional argument is passed as None;
        # presumably this is the observation space (left unchanged) —
        # confirm against LambdaObservationV0's signature.
        wrapped_env = LambdaObservationV0(
            env, lambda observation: observation + observation_shift, None
        )
        wrapped_obs, _, _, _, _ = wrapped_env.step(actions)

        # np.alltrue was removed in NumPy 2.0. assert_array_equal also
        # reports which elements differ on failure.
        np.testing.assert_array_equal(wrapped_obs, obs + observation_shift)
    finally:
        # The wrapper shares the underlying vector env, so closing env
        # releases the resources of every sub-environment.
        env.close()