mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-27 16:57:10 +00:00
Misc bug fixes (#516)
This commit is contained in:
@@ -904,7 +904,6 @@ def make_vec(
|
||||
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
|
||||
num_envs: Number of environments to create
|
||||
vectorization_mode: How to vectorize the environment. Can be either "async", "sync" or "custom"
|
||||
kwargs: Additional arguments to pass to the environment constructor.
|
||||
vector_kwargs: Additional arguments to pass to the vectorized environment constructor.
|
||||
wrappers: A sequence of wrapper functions to apply to the environment. Can only be used in "sync" or "async" mode.
|
||||
**kwargs: Additional arguments to pass to the environment constructor.
|
||||
@@ -953,13 +952,15 @@ def make_vec(
|
||||
|
||||
def _create_env():
|
||||
# Env creator for use with sync and async modes
|
||||
_kwargs_copy = _kwargs.copy()
|
||||
|
||||
render_mode = _kwargs.get("render_mode", None)
|
||||
if render_mode is not None:
|
||||
inner_render_mode = (
|
||||
render_mode[: -len("_list")]
|
||||
if render_mode is not None and render_mode.endswith("_list")
|
||||
if render_mode.endswith("_list")
|
||||
else render_mode
|
||||
)
|
||||
_kwargs_copy = _kwargs.copy()
|
||||
_kwargs_copy["render_mode"] = inner_render_mode
|
||||
|
||||
_env = env_creator(**_kwargs_copy)
|
||||
|
@@ -284,11 +284,18 @@ class VectorWrapper(VectorEnv):
|
||||
# explicitly forward the methods defined in VectorEnv
|
||||
# to self.env (instead of the base class)
|
||||
|
||||
def reset(self, **kwargs):
|
||||
"""Reset all environments."""
|
||||
return self.env.reset(**kwargs)
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: int | list[int] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Reset all environment using seed and options."""
|
||||
return self.env.reset(seed=seed, options=options)
|
||||
|
||||
def step(self, actions):
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
||||
"""Step all environments."""
|
||||
return self.env.step(actions)
|
||||
|
||||
@@ -301,7 +308,7 @@ class VectorWrapper(VectorEnv):
|
||||
return self.env.close_extras(**kwargs)
|
||||
|
||||
# implicitly forward all other methods and attributes to self.env
|
||||
def __getattr__(self, name):
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Forward all other attributes to the base environment."""
|
||||
if name.startswith("_"):
|
||||
raise AttributeError(f"attempted to get missing private attribute '{name}'")
|
||||
@@ -382,17 +389,23 @@ class VectorWrapper(VectorEnv):
|
||||
class VectorObservationWrapper(VectorWrapper):
|
||||
"""Wraps the vectorized environment to allow a modular transformation of the observation. Equivalent to :class:`gym.ObservationWrapper` for vectorized environments."""
|
||||
|
||||
def reset(self, **kwargs):
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: int | list[int] | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
|
||||
observation = self.env.reset(**kwargs)
|
||||
return self.observation(observation)
|
||||
obs, info = self.env.reset(seed=seed, options=options)
|
||||
return self.observation(obs), info
|
||||
|
||||
def step(self, actions):
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
||||
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
|
||||
observation, reward, termination, truncation, info = self.env.step(actions)
|
||||
return (
|
||||
self.observation(observation),
|
||||
observation,
|
||||
reward,
|
||||
termination,
|
||||
truncation,
|
||||
@@ -414,9 +427,11 @@ class VectorObservationWrapper(VectorWrapper):
|
||||
class VectorActionWrapper(VectorWrapper):
|
||||
"""Wraps the vectorized environment to allow a modular transformation of the actions. Equivalent of :class:`~gym.ActionWrapper` for vectorized environments."""
|
||||
|
||||
def step(self, actions: ActType):
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
||||
"""Steps through the environment using a modified action by :meth:`action`."""
|
||||
return self.env.step(self.action(actions))
|
||||
return self.env.step(self.actions(actions))
|
||||
|
||||
def actions(self, actions: ActType) -> ActType:
|
||||
"""Transform the actions before sending them to the environment.
|
||||
@@ -433,7 +448,9 @@ class VectorActionWrapper(VectorWrapper):
|
||||
class VectorRewardWrapper(VectorWrapper):
|
||||
"""Wraps the vectorized environment to allow a modular transformation of the reward. Equivalent of :class:`~gym.RewardWrapper` for vectorized environments."""
|
||||
|
||||
def step(self, actions):
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
|
||||
"""Steps through the environment returning a reward modified by :meth:`reward`."""
|
||||
observation, reward, termination, truncation, info = self.env.step(actions)
|
||||
return observation, self.reward(reward), termination, truncation, info
|
||||
|
@@ -117,8 +117,10 @@ def __getattr__(wrapper_name: str):
|
||||
AttributeError: If the wrapper does not exist.
|
||||
DeprecatedWrapper: If the version is not the latest.
|
||||
"""
|
||||
if wrapper_name == "vector":
|
||||
return importlib.import_module("gymnasium.experimental.wrappers.vector")
|
||||
# Check if the requested wrapper is in the _wrapper_to_class dictionary
|
||||
if wrapper_name in _wrapper_to_class:
|
||||
elif wrapper_name in _wrapper_to_class:
|
||||
import_stmt = (
|
||||
f"gymnasium.experimental.wrappers.{_wrapper_to_class[wrapper_name]}"
|
||||
)
|
||||
|
@@ -16,9 +16,7 @@ def test_flatten_observation_wrapper():
|
||||
reset_func=record_random_obs_reset,
|
||||
step_func=record_random_obs_step,
|
||||
)
|
||||
print(env.observation_space)
|
||||
wrapped_env = FlattenObservationV0(env)
|
||||
print(wrapped_env.observation_space)
|
||||
|
||||
obs, info = wrapped_env.reset()
|
||||
check_obs(env, wrapped_env, obs, info["obs"])
|
||||
|
@@ -4,7 +4,12 @@ import re
|
||||
|
||||
import pytest
|
||||
|
||||
import gymnasium
|
||||
import gymnasium.experimental.wrappers as wrappers
|
||||
from gymnasium.experimental.wrappers import (
|
||||
_wrapper_to_class, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
from gymnasium.experimental.wrappers import __all__
|
||||
|
||||
|
||||
def test_import_wrappers():
|
||||
@@ -41,3 +46,24 @@ def test_import_wrappers():
|
||||
),
|
||||
):
|
||||
getattr(wrappers, "NonexistentWrapper")
|
||||
|
||||
|
||||
def test_all_wrapper_shorten():
|
||||
"""Test that all wrappers in `__all__` are contained within the `_wrapper_to_class` conversion."""
|
||||
all_wrappers = set(__all__)
|
||||
all_wrappers.remove("vector")
|
||||
assert all_wrappers == set(_wrapper_to_class.keys())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("wrapper_name", __all__)
|
||||
def test_all_wrappers_shortened(wrapper_name):
|
||||
"""Check that each element of the `__all__` wrappers can be loaded, provided dependencies are installed."""
|
||||
if wrapper_name != "vector":
|
||||
try:
|
||||
assert getattr(gymnasium.experimental.wrappers, wrapper_name) is not None
|
||||
except gymnasium.error.DependencyNotInstalled as e:
|
||||
pytest.skip(str(e))
|
||||
|
||||
|
||||
def test_wrapper_vector():
|
||||
assert gymnasium.experimental.wrappers.vector is not None
|
||||
|
@@ -1,26 +0,0 @@
|
||||
"""Tests that all shortened imports for wrappers all work."""
|
||||
|
||||
import pytest
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.experimental.wrappers import (
|
||||
_wrapper_to_class, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
from gymnasium.experimental.wrappers import __all__
|
||||
|
||||
|
||||
def test_all_wrapper_shorten():
|
||||
"""Test that all wrappers in `__all__` are contained within the `_wrapper_to_class` conversion."""
|
||||
all_wrappers = set(__all__)
|
||||
all_wrappers.remove("vector")
|
||||
assert all_wrappers == set(_wrapper_to_class.keys())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("wrapper_name", __all__)
|
||||
def test_all_wrappers_shortened(wrapper_name):
|
||||
"""Check that each element of the `__all__` wrappers can be loaded, provided dependencies are installed."""
|
||||
if wrapper_name != "vector":
|
||||
try:
|
||||
assert getattr(gymnasium.experimental.wrappers, wrapper_name) is not None
|
||||
except gymnasium.error.DependencyNotInstalled as e:
|
||||
pytest.skip(str(e))
|
@@ -1,6 +1,10 @@
|
||||
"""Test suite for ResizeObservationV0."""
|
||||
import numpy as np
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.experimental.wrappers import ResizeObservationV0
|
||||
from gymnasium.spaces import Box
|
||||
from tests.experimental.wrappers.utils import (
|
||||
@@ -11,17 +15,42 @@ from tests.experimental.wrappers.utils import (
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
|
||||
def test_resize_observation_wrapper():
|
||||
"""Test the ``ResizeObservation`` that the observation has changed size."""
|
||||
env = GenericTestEnv(
|
||||
@pytest.mark.parametrize(
|
||||
"env",
|
||||
(
|
||||
GenericTestEnv(
|
||||
observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8),
|
||||
reset_func=record_random_obs_reset,
|
||||
step_func=record_random_obs_step,
|
||||
)
|
||||
),
|
||||
GenericTestEnv(
|
||||
observation_space=Box(0, 255, shape=(60, 60), dtype=np.uint8),
|
||||
reset_func=record_random_obs_reset,
|
||||
step_func=record_random_obs_step,
|
||||
),
|
||||
),
|
||||
)
|
||||
def test_resize_observation_wrapper(env):
|
||||
"""Test the ``ResizeObservation`` that the observation has changed size."""
|
||||
|
||||
wrapped_env = ResizeObservationV0(env, (25, 25))
|
||||
assert wrapped_env.observation_space.shape[:2] == (25, 25)
|
||||
|
||||
obs, info = wrapped_env.reset()
|
||||
check_obs(env, wrapped_env, obs, info["obs"])
|
||||
|
||||
obs, _, _, _, info = wrapped_env.step(None)
|
||||
check_obs(env, wrapped_env, obs, info["obs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", ((10, 10), (20, 20), (60, 60), (100, 100)))
|
||||
def test_resize_shapes(shape: tuple[int, int]):
|
||||
env = ResizeObservationV0(gym.make("CarRacing-v2"), shape)
|
||||
assert env.observation_space == Box(
|
||||
low=0, high=255, shape=shape + (3,), dtype=np.uint8
|
||||
)
|
||||
|
||||
obs, info = env.reset()
|
||||
assert obs in env.observation_space
|
||||
obs, _, _, _, _ = env.step(env.action_space.sample())
|
||||
assert obs in env.observation_space
|
||||
|
@@ -382,7 +382,6 @@ def test_space_sample_mask(space: Space, mask, n_trials: int = 100):
|
||||
expected_frequency = (
|
||||
np.ones(space.shape) * np.where(mask == 2, 0.5, mask) * n_trials
|
||||
)
|
||||
print(expected_frequency)
|
||||
observed_frequency = np.sum(samples, axis=0)
|
||||
assert space.shape == expected_frequency.shape == observed_frequency.shape
|
||||
|
||||
|
Reference in New Issue
Block a user