Files
Gymnasium/tests/experimental/wrappers/vector/test_vector_wrappers.py

124 lines
4.5 KiB
Python

"""Tests that the vectorised wrappers operate identically in `VectorEnv(Wrapper)` and `VectorWrapper(VectorEnv)`.
The exceptions are the data converter wrappers (`JaxToTorch`, `JaxToNumpy` and `NumpyToJax`), which are not covered by these tests.
"""
from __future__ import annotations
from typing import Any
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.experimental import wrappers
from gymnasium.experimental.vector import VectorEnv
from gymnasium.spaces import Box, Dict, Discrete
from gymnasium.utils.env_checker import data_equivalence
from tests.testing_env import GenericTestEnv
@pytest.fixture
def custom_environments():
    """Register ``CustomDictEnv-v0`` for the duration of a test.

    The environment is a ``GenericTestEnv`` whose observation space is a
    ``Dict`` with a ``Box(0, 1)`` entry under ``"a"`` and a ``Discrete(5)``
    entry under ``"b"``. The registry entry is deleted again once the test
    using this fixture has finished.
    """
    env_id = "CustomDictEnv-v0"

    def _make_dict_env():
        dict_space = Dict({"a": Box(0, 1), "b": Discrete(5)})
        return GenericTestEnv(observation_space=dict_space)

    gym.register(env_id, _make_dict_env)

    yield

    # Teardown: drop the temporary registration so it does not leak into other tests.
    del gym.registry[env_id]
@pytest.mark.parametrize("num_envs", (1, 3))
@pytest.mark.parametrize(
    "env_id, wrapper_name, kwargs",
    (
        ("CustomDictEnv-v0", "FilterObservationV0", {"filter_keys": ["a"]}),
        ("CartPole-v1", "FlattenObservationV0", {}),
        ("CarRacing-v2", "GrayscaleObservationV0", {}),
        # ("CarRacing-v2", "ResizeObservationV0", {"shape": (35, 45)}),
        ("CarRacing-v2", "ReshapeObservationV0", {"shape": (96, 48, 6)}),
        ("CartPole-v1", "RescaleObservationV0", {"min_obs": 0, "max_obs": 1}),
        ("CartPole-v1", "DtypeObservationV0", {"dtype": np.int32}),
        # ("CartPole-v1", "PixelObservationV0", {}),
        # ("CartPole-v1", "NormalizeObservationV0", {}),
        # ("CartPole-v1", "TimeAwareObservationV0", {}),
        # ("CartPole-v1", "FrameStackObservationV0", {}),
        # ("CartPole-v1", "DelayObservationV0", {}),
        ("MountainCarContinuous-v0", "ClipActionV0", {}),
        (
            "MountainCarContinuous-v0",
            "RescaleActionV0",
            {"min_action": 1, "max_action": 2},
        ),
        # ("CartPole-v1", "StickyActionV0", {}),
        ("CartPole-v1", "ClipRewardV0", {"min_reward": 0.25, "max_reward": 0.75}),
        # ("CartPole-v1", "NormalizeRewardV1", {}),
    ),
)
def test_vector_wrapper_equivalence(
    env_id: str,
    wrapper_name: str,
    kwargs: dict[str, Any],
    num_envs: int,
    custom_environments,
    vectorization_mode: str = "sync",
    num_steps: int = 50,
):
    """Check that a wrapper gives the same results whichever side of vectorisation it is applied on.

    Two environments are built from the same ``env_id``:

    * ``wrapper_vector_env``: the wrapper from ``wrappers.vector`` applied to
      the result of ``gym.make_vec``.
    * ``vector_wrapper_env``: ``gym.make_vec`` with the wrapper of the same
      name from ``wrappers`` passed through its ``wrappers`` argument.

    Their spaces and ``num_envs`` must be equal, and after resetting both with
    the same seed, every element returned by ``reset`` and by ``num_steps``
    calls to ``step`` (fed the same sampled action) must be equivalent
    according to ``data_equivalence``.

    Args:
        env_id: Id of the environment to vectorise.
        wrapper_name: Attribute name looked up on both ``wrappers.vector`` and ``wrappers``.
        kwargs: Keyword arguments forwarded to both wrapper constructors.
        num_envs: Number of sub-environments.
        custom_environments: Fixture that registers ``CustomDictEnv-v0``.
        vectorization_mode: Mode passed to ``gym.make_vec``.
        num_steps: Number of steps to compare after the reset.
    """
    # Resolve both wrapper classes first so that a missing name fails before any
    # environment has been created.
    vector_wrapper = getattr(wrappers.vector, wrapper_name)
    env_wrapper = getattr(wrappers, wrapper_name)

    wrapper_vector_env: VectorEnv = vector_wrapper(
        gym.make_vec(
            id=env_id, num_envs=num_envs, vectorization_mode=vectorization_mode
        ),
        **kwargs,
    )
    # The try/finally blocks guarantee both environments are closed even when an
    # assertion below fails; otherwise a failing case would leak them.
    try:
        vector_wrapper_env = gym.make_vec(
            id=env_id,
            num_envs=num_envs,
            vectorization_mode=vectorization_mode,
            wrappers=(lambda env: env_wrapper(env, **kwargs),),
        )
        try:
            assert wrapper_vector_env.action_space == vector_wrapper_env.action_space
            assert (
                wrapper_vector_env.observation_space
                == vector_wrapper_env.observation_space
            )
            assert (
                wrapper_vector_env.single_action_space
                == vector_wrapper_env.single_action_space
            )
            assert (
                wrapper_vector_env.single_observation_space
                == vector_wrapper_env.single_observation_space
            )

            assert wrapper_vector_env.num_envs == vector_wrapper_env.num_envs

            wrapper_vector_obs, wrapper_vector_info = wrapper_vector_env.reset(
                seed=123
            )
            vector_wrapper_obs, vector_wrapper_info = vector_wrapper_env.reset(
                seed=123
            )

            assert data_equivalence(wrapper_vector_obs, vector_wrapper_obs)
            assert data_equivalence(wrapper_vector_info, vector_wrapper_info)

            for _ in range(num_steps):
                # The same action is fed to both environments.
                action = wrapper_vector_env.action_space.sample()
                wrapper_vector_step_returns = wrapper_vector_env.step(action)
                vector_wrapper_step_returns = vector_wrapper_env.step(action)

                # zip() stops at the shorter input, so check the lengths
                # explicitly to avoid silently skipping trailing elements.
                assert len(wrapper_vector_step_returns) == len(
                    vector_wrapper_step_returns
                )
                for wrapper_vector_return, vector_wrapper_return in zip(
                    wrapper_vector_step_returns, vector_wrapper_step_returns
                ):
                    assert data_equivalence(
                        wrapper_vector_return, vector_wrapper_return
                    )
        finally:
            vector_wrapper_env.close()
    finally:
        wrapper_vector_env.close()
# Parametrisations that are currently disabled for the equivalence test above:
# ("CartPole-v1", "LambdaObservationV0", {"func": lambda obs: obs + 1}),
# ("CartPole-v1", "LambdaActionV0", {"func": lambda action: action + 1}),
# ("CartPole-v1", "LambdaRewardV0", {"func": lambda reward: reward + 1}),
# (vector.JaxToNumpyV0, {}, {}),
# (vector.JaxToTorchV0, {}, {}),
# (vector.NumpyToTorchV0, {}, {}),
# ("CartPole-v1", "RecordEpisodeStatisticsV0", {}),  # the time taken that is stored in `info` differs between two instances, so the results are not equivalent