mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-23 15:04:20 +00:00
124 lines
4.5 KiB
Python
124 lines
4.5 KiB
Python
"""Tests that the vectorised wrappers operate identically in `VectorEnv(Wrapper)` and `VectorWrapper(VectorEnv)`.
|
|
|
|
The exceptions are the data converter wrappers (`JaxToTorch`, `JaxToNumpy` and `NumpyToJax`), which are not covered by this test.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
import gymnasium as gym
|
|
from gymnasium.experimental import wrappers
|
|
from gymnasium.experimental.vector import VectorEnv
|
|
from gymnasium.spaces import Box, Dict, Discrete
|
|
from gymnasium.utils.env_checker import data_equivalence
|
|
from tests.testing_env import GenericTestEnv
|
|
|
|
|
|
@pytest.fixture
def custom_environments():
    """Register ``CustomDictEnv-v0`` while a test runs and unregister it afterwards.

    The registered entry point builds a ``GenericTestEnv`` whose observation
    space is a ``Dict`` with a ``Box(0, 1)`` entry under ``"a"`` and a
    ``Discrete(5)`` entry under ``"b"``.
    """
    env_id = "CustomDictEnv-v0"

    def _make_dict_env():
        # A new observation space is created for every environment instance.
        return GenericTestEnv(
            observation_space=Dict({"a": Box(0, 1), "b": Discrete(5)})
        )

    gym.register(env_id, _make_dict_env)

    yield

    # Teardown: drop the temporary registration so it does not outlive the test.
    del gym.registry[env_id]
|
|
|
|
|
|
@pytest.mark.parametrize("num_envs", (1, 3))
@pytest.mark.parametrize(
    "env_id, wrapper_name, kwargs",
    (
        ("CustomDictEnv-v0", "FilterObservationV0", {"filter_keys": ["a"]}),
        ("CartPole-v1", "FlattenObservationV0", {}),
        ("CarRacing-v2", "GrayscaleObservationV0", {}),
        # ("CarRacing-v2", "ResizeObservationV0", {"shape": (35, 45)}),
        ("CarRacing-v2", "ReshapeObservationV0", {"shape": (96, 48, 6)}),
        ("CartPole-v1", "RescaleObservationV0", {"min_obs": 0, "max_obs": 1}),
        ("CartPole-v1", "DtypeObservationV0", {"dtype": np.int32}),
        # ("CartPole-v1", "PixelObservationV0", {}),
        # ("CartPole-v1", "NormalizeObservationV0", {}),
        # ("CartPole-v1", "TimeAwareObservationV0", {}),
        # ("CartPole-v1", "FrameStackObservationV0", {}),
        # ("CartPole-v1", "DelayObservationV0", {}),
        ("MountainCarContinuous-v0", "ClipActionV0", {}),
        (
            "MountainCarContinuous-v0",
            "RescaleActionV0",
            {"min_action": 1, "max_action": 2},
        ),
        # ("CartPole-v1", "StickyActionV0", {}),
        ("CartPole-v1", "ClipRewardV0", {"min_reward": 0.25, "max_reward": 0.75}),
        # ("CartPole-v1", "NormalizeRewardV1", {}),
    ),
)
def test_vector_wrapper_equivalence(
    env_id: str,
    wrapper_name: str,
    kwargs: dict[str, Any],
    num_envs: int,
    custom_environments,
    vectorization_mode: str = "sync",
    num_steps: int = 50,
):
    """Check that a vector wrapper and its single-env counterpart behave identically.

    Two vector environments are built from the same ``env_id``:

    * ``wrapper_vector_env``: the wrapper looked up on ``wrappers.vector`` applied
      on top of ``gym.make_vec(...)``.
    * ``vector_wrapper_env``: ``gym.make_vec(...)`` with the same-named wrapper
      looked up on ``wrappers`` passed through the ``wrappers`` argument.

    Their spaces, ``num_envs``, reset outputs and per-step outputs must all match.

    Args:
        env_id: Id of the environment to vectorise.
        wrapper_name: Attribute name of the wrapper, looked up on both
            ``wrappers.vector`` and ``wrappers``.
        kwargs: Keyword arguments forwarded to both wrapper constructors.
        num_envs: Number of sub-environments in each vector environment.
        custom_environments: Fixture requested so that ``CustomDictEnv-v0`` is
            registered while the test runs.
        vectorization_mode: Forwarded to ``gym.make_vec``.
        num_steps: Number of steps over which the outputs are compared.
    """
    # Local import: only this test needs it, and it keeps the block self-contained.
    from contextlib import ExitStack

    # ExitStack guarantees both environments are closed even when an assertion
    # (or the construction of the second environment) fails part-way through.
    with ExitStack() as cleanup:
        vector_wrapper = getattr(wrappers.vector, wrapper_name)
        wrapper_vector_env: VectorEnv = vector_wrapper(
            gym.make_vec(
                id=env_id, num_envs=num_envs, vectorization_mode=vectorization_mode
            ),
            **kwargs,
        )
        cleanup.callback(wrapper_vector_env.close)

        env_wrapper = getattr(wrappers, wrapper_name)
        vector_wrapper_env = gym.make_vec(
            id=env_id,
            num_envs=num_envs,
            vectorization_mode=vectorization_mode,
            wrappers=(lambda env: env_wrapper(env, **kwargs),),
        )
        cleanup.callback(vector_wrapper_env.close)

        assert wrapper_vector_env.action_space == vector_wrapper_env.action_space
        assert (
            wrapper_vector_env.observation_space
            == vector_wrapper_env.observation_space
        )
        assert (
            wrapper_vector_env.single_action_space
            == vector_wrapper_env.single_action_space
        )
        assert (
            wrapper_vector_env.single_observation_space
            == vector_wrapper_env.single_observation_space
        )

        assert wrapper_vector_env.num_envs == vector_wrapper_env.num_envs

        wrapper_vector_obs, wrapper_vector_info = wrapper_vector_env.reset(seed=123)
        vector_wrapper_obs, vector_wrapper_info = vector_wrapper_env.reset(seed=123)

        assert data_equivalence(wrapper_vector_obs, vector_wrapper_obs)
        assert data_equivalence(wrapper_vector_info, vector_wrapper_info)

        # Seed the sampler once so that a failing run can be reproduced exactly.
        action_space = wrapper_vector_env.action_space
        action_space.seed(123)

        for _ in range(num_steps):
            # The same action is fed to both environments on every step.
            action = action_space.sample()
            wrapper_vector_step_returns = wrapper_vector_env.step(action)
            vector_wrapper_step_returns = vector_wrapper_env.step(action)

            # zip() stops at the shorter input, so check the lengths explicitly.
            assert len(wrapper_vector_step_returns) == len(
                vector_wrapper_step_returns
            )
            for wrapper_vector_return, vector_wrapper_return in zip(
                wrapper_vector_step_returns, vector_wrapper_step_returns
            ):
                assert data_equivalence(wrapper_vector_return, vector_wrapper_return)
|
|
|
|
|
|
# ("CartPole-v1", "LambdaObservationV0", {"func": lambda obs: obs + 1}),
|
|
# ("CartPole-v1", "LambdaActionV0", {"func": lambda action: action + 1}),
|
|
# ("CartPole-v1", "LambdaRewardV0", {"func": lambda reward: reward + 1}),
|
|
# (vector.JaxToNumpyV0, {}, {}),
|
|
# (vector.JaxToTorchV0, {}, {}),
|
|
# (vector.NumpyToTorchV0, {}, {}),
|
|
# ("CartPole-v1", "RecordEpisodeStatisticsV0", {}),  # excluded: the elapsed time reported in `info` differs between two instances, so the outputs are not equivalent
|