Files
Gymnasium/tests/wrappers/vector/test_vector_wrappers.py
James Mochizuki-Freeman 89bedf1f33 Fix Box casting and sampling edge-cases (#774)
Co-authored-by: Mark Towers <mark.m.towers@gmail.com>
2024-03-23 16:19:52 +00:00

116 lines
4.1 KiB
Python

"""Tests that the vectorised wrappers operate identically in `VectorEnv(Wrapper)` and `VectorWrapper(VectorEnv)`.
The exception is the data converter wrappers
* Data conversion wrappers - `JaxToTorch`, `JaxToNumpy` and `NumpyToJax`
* Normalizing wrappers - `NormalizeObservation` and `NormalizeReward`
* Different implementations - `LambdaObservation`, `LambdaReward` and `LambdaAction`
* Different random sources - `StickyAction`
"""
from __future__ import annotations
from typing import Any
import numpy as np
import pytest
import gymnasium as gym
from gymnasium import wrappers
from gymnasium.spaces import Box, Dict, Discrete
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.vector import VectorEnv
from tests.testing_env import GenericTestEnv
@pytest.fixture
def custom_environments():
gym.register(
"CustomDictEnv-v0",
lambda: GenericTestEnv(
observation_space=Dict({"a": Box(0, 1), "b": Discrete(5)})
),
)
yield
del gym.registry["CustomDictEnv-v0"]
@pytest.mark.parametrize("num_envs", (1, 3))
@pytest.mark.parametrize(
"env_id, wrapper_name, kwargs",
(
("CustomDictEnv-v0", "FilterObservation", {"filter_keys": ["a"]}),
("CartPole-v1", "FlattenObservation", {}),
("CarRacing-v2", "GrayscaleObservation", {}),
("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}),
("CarRacing-v2", "ReshapeObservation", {"shape": (96, 48, 6)}),
("CartPole-v1", "RescaleObservation", {"min_obs": 0, "max_obs": 1}),
("CarRacing-v2", "DtypeObservation", {"dtype": np.int32}),
# ("CartPole-v1", "RenderObservation", {}), # not implemented
# ("CartPole-v1", "TimeAwareObservation", {}), # not implemented
# ("CartPole-v1", "FrameStackObservation", {}), # not implemented
# ("CartPole-v1", "DelayObservation", {}), # not implemented
("MountainCarContinuous-v0", "ClipAction", {}),
(
"MountainCarContinuous-v0",
"RescaleAction",
{"min_action": 1, "max_action": 2},
),
("CartPole-v1", "ClipReward", {"min_reward": -0.25, "max_reward": 0.75}),
),
)
def test_vector_wrapper_equivalence(
env_id: str,
wrapper_name: str,
kwargs: dict[str, Any],
num_envs: int,
custom_environments,
vectorization_mode: str = "sync",
num_steps: int = 50,
):
vector_wrapper = getattr(wrappers.vector, wrapper_name)
wrapper_vector_env: VectorEnv = vector_wrapper(
gym.make_vec(
id=env_id, num_envs=num_envs, vectorization_mode=vectorization_mode
),
**kwargs,
)
env_wrapper = getattr(wrappers, wrapper_name)
vector_wrapper_env = gym.make_vec(
id=env_id,
num_envs=num_envs,
vectorization_mode=vectorization_mode,
wrappers=(lambda env: env_wrapper(env, **kwargs),),
)
assert wrapper_vector_env.action_space == vector_wrapper_env.action_space
assert wrapper_vector_env.observation_space == vector_wrapper_env.observation_space
assert (
wrapper_vector_env.single_action_space == vector_wrapper_env.single_action_space
)
assert (
wrapper_vector_env.single_observation_space
== vector_wrapper_env.single_observation_space
)
assert wrapper_vector_env.num_envs == vector_wrapper_env.num_envs
wrapper_vector_obs, wrapper_vector_info = wrapper_vector_env.reset(seed=123)
vector_wrapper_obs, vector_wrapper_info = vector_wrapper_env.reset(seed=123)
assert data_equivalence(wrapper_vector_obs, vector_wrapper_obs)
assert data_equivalence(wrapper_vector_info, vector_wrapper_info)
for _ in range(num_steps):
action = wrapper_vector_env.action_space.sample()
wrapper_vector_step_returns = wrapper_vector_env.step(action)
vector_wrapper_step_returns = vector_wrapper_env.step(action)
for wrapper_vector_return, vector_wrapper_return in zip(
wrapper_vector_step_returns, vector_wrapper_step_returns
):
assert data_equivalence(wrapper_vector_return, vector_wrapper_return)
wrapper_vector_env.close()
vector_wrapper_env.close()