Misc bug fixes (#516)

This commit is contained in:
Mark Towers
2023-05-23 15:35:49 +01:00
committed by GitHub
parent e9d9515d51
commit 22a00c2a75
8 changed files with 104 additions and 58 deletions

View File

@@ -904,7 +904,6 @@ def make_vec(
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
num_envs: Number of environments to create
vectorization_mode: How to vectorize the environment. Can be either "async", "sync" or "custom"
kwargs: Additional arguments to pass to the environment constructor.
vector_kwargs: Additional arguments to pass to the vectorized environment constructor.
wrappers: A sequence of wrapper functions to apply to the environment. Can only be used in "sync" or "async" mode.
**kwargs: Additional arguments to pass to the environment constructor.
@@ -953,13 +952,15 @@ def make_vec(
def _create_env():
# Env creator for use with sync and async modes
_kwargs_copy = _kwargs.copy()
render_mode = _kwargs.get("render_mode", None)
if render_mode is not None:
inner_render_mode = (
render_mode[: -len("_list")]
if render_mode is not None and render_mode.endswith("_list")
if render_mode.endswith("_list")
else render_mode
)
_kwargs_copy = _kwargs.copy()
_kwargs_copy["render_mode"] = inner_render_mode
_env = env_creator(**_kwargs_copy)

View File

@@ -284,11 +284,18 @@ class VectorWrapper(VectorEnv):
# explicitly forward the methods defined in VectorEnv
# to self.env (instead of the base class)
def reset(self, **kwargs):
"""Reset all environments."""
return self.env.reset(**kwargs)
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Reset all environment using seed and options."""
return self.env.reset(seed=seed, options=options)
def step(self, actions):
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Step all environments."""
return self.env.step(actions)
@@ -301,7 +308,7 @@ class VectorWrapper(VectorEnv):
return self.env.close_extras(**kwargs)
# implicitly forward all other methods and attributes to self.env
def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
"""Forward all other attributes to the base environment."""
if name.startswith("_"):
raise AttributeError(f"attempted to get missing private attribute '{name}'")
@@ -382,17 +389,23 @@ class VectorWrapper(VectorEnv):
class VectorObservationWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the observation. Equivalent to :class:`gym.ObservationWrapper` for vectorized environments."""
def reset(self, **kwargs):
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
observation = self.env.reset(**kwargs)
return self.observation(observation)
obs, info = self.env.reset(seed=seed, options=options)
return self.observation(obs), info
def step(self, actions):
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
observation, reward, termination, truncation, info = self.env.step(actions)
return (
self.observation(observation),
observation,
reward,
termination,
truncation,
@@ -414,9 +427,11 @@ class VectorObservationWrapper(VectorWrapper):
class VectorActionWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the actions. Equivalent of :class:`~gym.ActionWrapper` for vectorized environments."""
def step(self, actions: ActType):
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Steps through the environment using a modified action by :meth:`action`."""
return self.env.step(self.action(actions))
return self.env.step(self.actions(actions))
def actions(self, actions: ActType) -> ActType:
"""Transform the actions before sending them to the environment.
@@ -433,7 +448,9 @@ class VectorActionWrapper(VectorWrapper):
class VectorRewardWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the reward. Equivalent of :class:`~gym.RewardWrapper` for vectorized environments."""
def step(self, actions):
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Steps through the environment returning a reward modified by :meth:`reward`."""
observation, reward, termination, truncation, info = self.env.step(actions)
return observation, self.reward(reward), termination, truncation, info

View File

@@ -117,8 +117,10 @@ def __getattr__(wrapper_name: str):
AttributeError: If the wrapper does not exist.
DeprecatedWrapper: If the version is not the latest.
"""
if wrapper_name == "vector":
return importlib.import_module("gymnasium.experimental.wrappers.vector")
# Check if the requested wrapper is in the _wrapper_to_class dictionary
if wrapper_name in _wrapper_to_class:
elif wrapper_name in _wrapper_to_class:
import_stmt = (
f"gymnasium.experimental.wrappers.{_wrapper_to_class[wrapper_name]}"
)

View File

@@ -16,9 +16,7 @@ def test_flatten_observation_wrapper():
reset_func=record_random_obs_reset,
step_func=record_random_obs_step,
)
print(env.observation_space)
wrapped_env = FlattenObservationV0(env)
print(wrapped_env.observation_space)
obs, info = wrapped_env.reset()
check_obs(env, wrapped_env, obs, info["obs"])

View File

@@ -4,7 +4,12 @@ import re
import pytest
import gymnasium
import gymnasium.experimental.wrappers as wrappers
from gymnasium.experimental.wrappers import (
_wrapper_to_class, # pyright: ignore[reportPrivateUsage]
)
from gymnasium.experimental.wrappers import __all__
def test_import_wrappers():
@@ -41,3 +46,24 @@ def test_import_wrappers():
),
):
getattr(wrappers, "NonexistentWrapper")
def test_all_wrapper_shorten():
    """Check that `__all__`, minus the ``"vector"`` entry, equals the key set of `_wrapper_to_class`."""
    exported_names = set(__all__)
    # ``remove`` rather than ``discard``: a missing "vector" entry raises KeyError and fails the test.
    exported_names.remove("vector")

    registered_names = set(_wrapper_to_class)
    assert exported_names == registered_names
@pytest.mark.parametrize("wrapper_name", __all__)
def test_all_wrappers_shortened(wrapper_name):
    """Check that every name in `__all__` can be fetched from ``gymnasium.experimental.wrappers``.

    The ``"vector"`` entry is not checked here. If fetching a name raises
    ``DependencyNotInstalled``, the test case is skipped instead of failed.
    """
    if wrapper_name == "vector":
        return

    try:
        loaded = getattr(gymnasium.experimental.wrappers, wrapper_name)
        assert loaded is not None
    except gymnasium.error.DependencyNotInstalled as missing_dependency:
        # An optional dependency is absent: report the reason as a skip.
        pytest.skip(str(missing_dependency))
def test_wrapper_vector():
    """Check that ``gymnasium.experimental.wrappers.vector`` resolves to a non-``None`` object."""
    vector_attr = gymnasium.experimental.wrappers.vector
    assert vector_attr is not None

View File

@@ -1,26 +0,0 @@
"""Tests that all shortened imports for wrappers all work."""
import pytest
import gymnasium
from gymnasium.experimental.wrappers import (
_wrapper_to_class, # pyright: ignore[reportPrivateUsage]
)
from gymnasium.experimental.wrappers import __all__
def test_all_wrapper_shorten():
    """Check that `__all__`, minus the ``"vector"`` entry, equals the key set of `_wrapper_to_class`."""
    exported_names = set(__all__)
    # ``remove`` rather than ``discard``: a missing "vector" entry raises KeyError and fails the test.
    exported_names.remove("vector")

    registered_names = set(_wrapper_to_class)
    assert exported_names == registered_names
@pytest.mark.parametrize("wrapper_name", __all__)
def test_all_wrappers_shortened(wrapper_name):
    """Check that every name in `__all__` can be fetched from ``gymnasium.experimental.wrappers``.

    The ``"vector"`` entry is not checked here. If fetching a name raises
    ``DependencyNotInstalled``, the test case is skipped instead of failed.
    """
    if wrapper_name == "vector":
        return

    try:
        loaded = getattr(gymnasium.experimental.wrappers, wrapper_name)
        assert loaded is not None
    except gymnasium.error.DependencyNotInstalled as missing_dependency:
        # An optional dependency is absent: report the reason as a skip.
        pytest.skip(str(missing_dependency))

View File

@@ -1,6 +1,10 @@
"""Test suite for ResizeObservationV0."""
import numpy as np
from __future__ import annotations
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.experimental.wrappers import ResizeObservationV0
from gymnasium.spaces import Box
from tests.experimental.wrappers.utils import (
@@ -11,17 +15,42 @@ from tests.experimental.wrappers.utils import (
from tests.testing_env import GenericTestEnv
def test_resize_observation_wrapper():
"""Test the ``ResizeObservation`` that the observation has changed size."""
env = GenericTestEnv(
@pytest.mark.parametrize(
"env",
(
GenericTestEnv(
observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8),
reset_func=record_random_obs_reset,
step_func=record_random_obs_step,
)
),
GenericTestEnv(
observation_space=Box(0, 255, shape=(60, 60), dtype=np.uint8),
reset_func=record_random_obs_reset,
step_func=record_random_obs_step,
),
),
)
def test_resize_observation_wrapper(env):
"""Test the ``ResizeObservation`` that the observation has changed size."""
wrapped_env = ResizeObservationV0(env, (25, 25))
assert wrapped_env.observation_space.shape[:2] == (25, 25)
obs, info = wrapped_env.reset()
check_obs(env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None)
check_obs(env, wrapped_env, obs, info["obs"])
@pytest.mark.parametrize("shape", ((10, 10), (20, 20), (60, 60), (100, 100)))
def test_resize_shapes(shape: tuple[int, int]):
    """Check that ``ResizeObservationV0`` around ``CarRacing-v2`` produces observations of the requested shape.

    Args:
        shape: Two-element target shape passed to the wrapper. The expected
            observation space is a ``uint8`` ``Box`` in ``[0, 255]`` with this
            shape followed by a trailing dimension of 3.
    """
    env = ResizeObservationV0(gym.make("CarRacing-v2"), shape)
    try:
        expected_space = Box(low=0, high=255, shape=shape + (3,), dtype=np.uint8)
        assert env.observation_space == expected_space

        obs, _ = env.reset()
        assert obs in env.observation_space

        obs, _, _, _, _ = env.step(env.action_space.sample())
        assert obs in env.observation_space
    finally:
        # Close the environment even when an assertion fails so it is not leaked across test cases.
        env.close()

View File

@@ -382,7 +382,6 @@ def test_space_sample_mask(space: Space, mask, n_trials: int = 100):
expected_frequency = (
np.ones(space.shape) * np.where(mask == 2, 0.5, mask) * n_trials
)
print(expected_frequency)
observed_frequency = np.sum(samples, axis=0)
assert space.shape == expected_frequency.shape == observed_frequency.shape