Misc bug fixes (#516)

This commit is contained in:
Mark Towers
2023-05-23 15:35:49 +01:00
committed by GitHub
parent e9d9515d51
commit 22a00c2a75
8 changed files with 104 additions and 58 deletions

View File

@@ -904,7 +904,6 @@ def make_vec(
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0' id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
num_envs: Number of environments to create num_envs: Number of environments to create
vectorization_mode: How to vectorize the environment. Can be either "async", "sync" or "custom" vectorization_mode: How to vectorize the environment. Can be either "async", "sync" or "custom"
kwargs: Additional arguments to pass to the environment constructor.
vector_kwargs: Additional arguments to pass to the vectorized environment constructor. vector_kwargs: Additional arguments to pass to the vectorized environment constructor.
wrappers: A sequence of wrapper functions to apply to the environment. Can only be used in "sync" or "async" mode. wrappers: A sequence of wrapper functions to apply to the environment. Can only be used in "sync" or "async" mode.
**kwargs: Additional arguments to pass to the environment constructor. **kwargs: Additional arguments to pass to the environment constructor.
@@ -953,14 +952,16 @@ def make_vec(
def _create_env(): def _create_env():
# Env creator for use with sync and async modes # Env creator for use with sync and async modes
render_mode = _kwargs.get("render_mode", None)
inner_render_mode = (
render_mode[: -len("_list")]
if render_mode is not None and render_mode.endswith("_list")
else render_mode
)
_kwargs_copy = _kwargs.copy() _kwargs_copy = _kwargs.copy()
_kwargs_copy["render_mode"] = inner_render_mode
render_mode = _kwargs.get("render_mode", None)
if render_mode is not None:
inner_render_mode = (
render_mode[: -len("_list")]
if render_mode.endswith("_list")
else render_mode
)
_kwargs_copy["render_mode"] = inner_render_mode
_env = env_creator(**_kwargs_copy) _env = env_creator(**_kwargs_copy)
_env.spec = spec_ _env.spec = spec_

View File

@@ -284,11 +284,18 @@ class VectorWrapper(VectorEnv):
# explicitly forward the methods defined in VectorEnv # explicitly forward the methods defined in VectorEnv
# to self.env (instead of the base class) # to self.env (instead of the base class)
def reset(self, **kwargs): def reset(
"""Reset all environments.""" self,
return self.env.reset(**kwargs) *,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Reset all environments using seed and options."""
return self.env.reset(seed=seed, options=options)
def step(self, actions): def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Step all environments.""" """Step all environments."""
return self.env.step(actions) return self.env.step(actions)
@@ -301,7 +308,7 @@ class VectorWrapper(VectorEnv):
return self.env.close_extras(**kwargs) return self.env.close_extras(**kwargs)
# implicitly forward all other methods and attributes to self.env # implicitly forward all other methods and attributes to self.env
def __getattr__(self, name): def __getattr__(self, name: str) -> Any:
"""Forward all other attributes to the base environment.""" """Forward all other attributes to the base environment."""
if name.startswith("_"): if name.startswith("_"):
raise AttributeError(f"attempted to get missing private attribute '{name}'") raise AttributeError(f"attempted to get missing private attribute '{name}'")
@@ -382,17 +389,23 @@ class VectorWrapper(VectorEnv):
class VectorObservationWrapper(VectorWrapper): class VectorObservationWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the observation. Equivalent to :class:`gym.ObservationWrapper` for vectorized environments.""" """Wraps the vectorized environment to allow a modular transformation of the observation. Equivalent to :class:`gym.ObservationWrapper` for vectorized environments."""
def reset(self, **kwargs): def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`.""" """Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
observation = self.env.reset(**kwargs) obs, info = self.env.reset(seed=seed, options=options)
return self.observation(observation) return self.observation(obs), info
def step(self, actions): def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`.""" """Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
observation, reward, termination, truncation, info = self.env.step(actions) observation, reward, termination, truncation, info = self.env.step(actions)
return ( return (
self.observation(observation), self.observation(observation),
observation,
reward, reward,
termination, termination,
truncation, truncation,
@@ -414,9 +427,11 @@ class VectorObservationWrapper(VectorWrapper):
class VectorActionWrapper(VectorWrapper): class VectorActionWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the actions. Equivalent of :class:`~gym.ActionWrapper` for vectorized environments.""" """Wraps the vectorized environment to allow a modular transformation of the actions. Equivalent of :class:`~gym.ActionWrapper` for vectorized environments."""
def step(self, actions: ActType): def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Steps through the environment using a modified action by :meth:`action`.""" """Steps through the environment using a modified action by :meth:`action`."""
return self.env.step(self.action(actions)) return self.env.step(self.actions(actions))
def actions(self, actions: ActType) -> ActType: def actions(self, actions: ActType) -> ActType:
"""Transform the actions before sending them to the environment. """Transform the actions before sending them to the environment.
@@ -433,7 +448,9 @@ class VectorActionWrapper(VectorWrapper):
class VectorRewardWrapper(VectorWrapper): class VectorRewardWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the reward. Equivalent of :class:`~gym.RewardWrapper` for vectorized environments.""" """Wraps the vectorized environment to allow a modular transformation of the reward. Equivalent of :class:`~gym.RewardWrapper` for vectorized environments."""
def step(self, actions): def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Steps through the environment returning a reward modified by :meth:`reward`.""" """Steps through the environment returning a reward modified by :meth:`reward`."""
observation, reward, termination, truncation, info = self.env.step(actions) observation, reward, termination, truncation, info = self.env.step(actions)
return observation, self.reward(reward), termination, truncation, info return observation, self.reward(reward), termination, truncation, info

View File

@@ -117,8 +117,10 @@ def __getattr__(wrapper_name: str):
AttributeError: If the wrapper does not exist. AttributeError: If the wrapper does not exist.
DeprecatedWrapper: If the version is not the latest. DeprecatedWrapper: If the version is not the latest.
""" """
if wrapper_name == "vector":
return importlib.import_module("gymnasium.experimental.wrappers.vector")
# Check if the requested wrapper is in the _wrapper_to_class dictionary # Check if the requested wrapper is in the _wrapper_to_class dictionary
if wrapper_name in _wrapper_to_class: elif wrapper_name in _wrapper_to_class:
import_stmt = ( import_stmt = (
f"gymnasium.experimental.wrappers.{_wrapper_to_class[wrapper_name]}" f"gymnasium.experimental.wrappers.{_wrapper_to_class[wrapper_name]}"
) )

View File

@@ -16,9 +16,7 @@ def test_flatten_observation_wrapper():
reset_func=record_random_obs_reset, reset_func=record_random_obs_reset,
step_func=record_random_obs_step, step_func=record_random_obs_step,
) )
print(env.observation_space)
wrapped_env = FlattenObservationV0(env) wrapped_env = FlattenObservationV0(env)
print(wrapped_env.observation_space)
obs, info = wrapped_env.reset() obs, info = wrapped_env.reset()
check_obs(env, wrapped_env, obs, info["obs"]) check_obs(env, wrapped_env, obs, info["obs"])

View File

@@ -4,7 +4,12 @@ import re
import pytest import pytest
import gymnasium
import gymnasium.experimental.wrappers as wrappers import gymnasium.experimental.wrappers as wrappers
from gymnasium.experimental.wrappers import (
_wrapper_to_class, # pyright: ignore[reportPrivateUsage]
)
from gymnasium.experimental.wrappers import __all__
def test_import_wrappers(): def test_import_wrappers():
@@ -41,3 +46,24 @@ def test_import_wrappers():
), ),
): ):
getattr(wrappers, "NonexistentWrapper") getattr(wrappers, "NonexistentWrapper")
def test_all_wrapper_shorten():
    """Verify `__all__`, minus its "vector" entry, matches the keys of `_wrapper_to_class`."""
    # "vector" is dropped from the exported names before comparing;
    # `remove` raises KeyError if it is not listed in `__all__`.
    exported_names = set(__all__)
    exported_names.remove("vector")

    shortened_names = set(_wrapper_to_class.keys())
    assert exported_names == shortened_names
@pytest.mark.parametrize("wrapper_name", __all__)
def test_all_wrappers_shortened(wrapper_name):
    """Check that every non-"vector" name in `__all__` resolves to a non-None attribute.

    The test is skipped when the lookup raises ``DependencyNotInstalled``.
    """
    if wrapper_name == "vector":
        # "vector" is deliberately not looked up by this test.
        return

    try:
        loaded = getattr(gymnasium.experimental.wrappers, wrapper_name)
        assert loaded is not None
    except gymnasium.error.DependencyNotInstalled as missing_dependency:
        pytest.skip(str(missing_dependency))
def test_wrapper_vector():
    """Check that ``gymnasium.experimental.wrappers.vector`` resolves to a non-None object."""
    vector_attr = gymnasium.experimental.wrappers.vector
    assert vector_attr is not None

View File

@@ -1,26 +0,0 @@
"""Tests that all shortened imports for wrappers all work."""
import pytest
import gymnasium
from gymnasium.experimental.wrappers import (
_wrapper_to_class, # pyright: ignore[reportPrivateUsage]
)
from gymnasium.experimental.wrappers import __all__
def test_all_wrapper_shorten():
    """Verify `__all__`, minus its "vector" entry, matches the keys of `_wrapper_to_class`."""
    # "vector" is dropped from the exported names before comparing;
    # `remove` raises KeyError if it is not listed in `__all__`.
    exported_names = set(__all__)
    exported_names.remove("vector")

    shortened_names = set(_wrapper_to_class.keys())
    assert exported_names == shortened_names
@pytest.mark.parametrize("wrapper_name", __all__)
def test_all_wrappers_shortened(wrapper_name):
    """Check that every non-"vector" name in `__all__` resolves to a non-None attribute.

    The test is skipped when the lookup raises ``DependencyNotInstalled``.
    """
    if wrapper_name == "vector":
        # "vector" is deliberately not looked up by this test.
        return

    try:
        loaded = getattr(gymnasium.experimental.wrappers, wrapper_name)
        assert loaded is not None
    except gymnasium.error.DependencyNotInstalled as missing_dependency:
        pytest.skip(str(missing_dependency))

View File

@@ -1,6 +1,10 @@
"""Test suite for ResizeObservationV0.""" """Test suite for ResizeObservationV0."""
import numpy as np from __future__ import annotations
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.experimental.wrappers import ResizeObservationV0 from gymnasium.experimental.wrappers import ResizeObservationV0
from gymnasium.spaces import Box from gymnasium.spaces import Box
from tests.experimental.wrappers.utils import ( from tests.experimental.wrappers.utils import (
@@ -11,17 +15,42 @@ from tests.experimental.wrappers.utils import (
from tests.testing_env import GenericTestEnv from tests.testing_env import GenericTestEnv
def test_resize_observation_wrapper(): @pytest.mark.parametrize(
"env",
(
GenericTestEnv(
observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8),
reset_func=record_random_obs_reset,
step_func=record_random_obs_step,
),
GenericTestEnv(
observation_space=Box(0, 255, shape=(60, 60), dtype=np.uint8),
reset_func=record_random_obs_reset,
step_func=record_random_obs_step,
),
),
)
def test_resize_observation_wrapper(env):
"""Test the ``ResizeObservation`` that the observation has changed size.""" """Test the ``ResizeObservation`` that the observation has changed size."""
env = GenericTestEnv(
observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8),
reset_func=record_random_obs_reset,
step_func=record_random_obs_step,
)
wrapped_env = ResizeObservationV0(env, (25, 25)) wrapped_env = ResizeObservationV0(env, (25, 25))
assert wrapped_env.observation_space.shape[:2] == (25, 25)
obs, info = wrapped_env.reset() obs, info = wrapped_env.reset()
check_obs(env, wrapped_env, obs, info["obs"]) check_obs(env, wrapped_env, obs, info["obs"])
obs, _, _, _, info = wrapped_env.step(None) obs, _, _, _, info = wrapped_env.step(None)
check_obs(env, wrapped_env, obs, info["obs"]) check_obs(env, wrapped_env, obs, info["obs"])
@pytest.mark.parametrize("shape", ((10, 10), (20, 20), (60, 60), (100, 100)))
def test_resize_shapes(shape: tuple[int, int]):
    """Check ``ResizeObservationV0`` on ``CarRacing-v2`` for several target shapes.

    The wrapped observation space must equal a ``uint8`` ``Box`` of shape
    ``shape + (3,)`` bounded by 0 and 255, and the observations returned by
    ``reset`` and ``step`` must be contained in that space.
    """
    env = ResizeObservationV0(gym.make("CarRacing-v2"), shape)
    try:
        assert env.observation_space == Box(
            low=0, high=255, shape=shape + (3,), dtype=np.uint8
        )

        obs, _ = env.reset()
        assert obs in env.observation_space

        obs, _, _, _, _ = env.step(env.action_space.sample())
        assert obs in env.observation_space
    finally:
        # Close the environment even when an assertion fails, so each
        # parametrized case releases what it created.
        env.close()

View File

@@ -382,7 +382,6 @@ def test_space_sample_mask(space: Space, mask, n_trials: int = 100):
expected_frequency = ( expected_frequency = (
np.ones(space.shape) * np.where(mask == 2, 0.5, mask) * n_trials np.ones(space.shape) * np.where(mask == 2, 0.5, mask) * n_trials
) )
print(expected_frequency)
observed_frequency = np.sum(samples, axis=0) observed_frequency = np.sum(samples, axis=0)
assert space.shape == expected_frequency.shape == observed_frequency.shape assert space.shape == expected_frequency.shape == observed_frequency.shape