mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 22:04:31 +00:00
Add deprecated wrapper error in gymnasium.experimental.wrappers (#341)
This commit is contained in:
2
.github/workflows/build.yml
vendored
2
.github/workflows/build.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
--tag gymnasium-all-docker .
|
||||
- name: Run tests
|
||||
run: docker run gymnasium-all-docker pytest tests/*
|
||||
- name: Run doctest
|
||||
- name: Run doctests
|
||||
run: docker run gymnasium-all-docker pytest --doctest-modules gymnasium/
|
||||
|
||||
build-necessary:
|
||||
|
@@ -101,7 +101,7 @@ We aimed to replace the wrappers in gymnasium v0.30.0 with these experimental wr
|
||||
* - `supersuit.clip_reward_v0 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/generic_wrappers/basic_wrappers.py#L74>`_
|
||||
- :class:`experimental.wrappers.ClipRewardV0`
|
||||
* - :class:`wrappers.NormalizeReward`
|
||||
- :class:`experimental.wrappers.NormalizeRewardV0`
|
||||
- :class:`experimental.wrappers.NormalizeRewardV1`
|
||||
```
|
||||
|
||||
### Common Wrappers
|
||||
|
@@ -37,7 +37,7 @@ title: Wrappers
|
||||
```{eval-rst}
|
||||
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.NormalizeRewardV0
|
||||
.. autoclass:: gymnasium.experimental.wrappers.NormalizeRewardV1
|
||||
```
|
||||
|
||||
## Other Wrappers
|
||||
|
@@ -181,6 +181,10 @@ class RetriesExceededError(Error):
|
||||
"""Error message for retries exceeding set number."""
|
||||
|
||||
|
||||
class DeprecatedWrapper(ImportError):
|
||||
"""Error message for importing an old version of a wrapper."""
|
||||
|
||||
|
||||
# Vectorized environments errors
|
||||
|
||||
|
||||
|
@@ -1,7 +1,9 @@
|
||||
"""`__init__` for experimental wrappers, to avoid loading the wrappers if unnecessary, we can hack python."""
|
||||
# pyright: reportUnsupportedDunderAll=false
|
||||
|
||||
import importlib
|
||||
import re
|
||||
|
||||
from gymnasium.error import DeprecatedWrapper
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -30,7 +32,7 @@ __all__ = [
|
||||
# --- Reward wrappers ---
|
||||
"LambdaRewardV0",
|
||||
"ClipRewardV0",
|
||||
"NormalizeRewardV0",
|
||||
"NormalizeRewardV1",
|
||||
# --- Common ---
|
||||
"AutoresetV0",
|
||||
"PassiveEnvCheckerV0",
|
||||
@@ -66,7 +68,7 @@ _wrapper_to_class = {
|
||||
# lambda_reward.py
|
||||
"ClipRewardV0": "lambda_reward",
|
||||
"LambdaRewardV0": "lambda_reward",
|
||||
"NormalizeRewardV0": "lambda_reward",
|
||||
"NormalizeRewardV1": "lambda_reward",
|
||||
# stateful_action
|
||||
"StickyActionV0": "stateful_action",
|
||||
# stateful_observation
|
||||
@@ -99,21 +101,64 @@ _wrapper_to_class = {
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
"""To avoid having to load all wrappers on `import gymnasium` with all of their extra modules.
|
||||
def __getattr__(wrapper_name: str):
|
||||
"""Load a wrapper by name.
|
||||
|
||||
This optimises the loading of gymnasium.
|
||||
This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used.
|
||||
Errors will be raised if the wrapper does not exist or if the version is not the latest.
|
||||
|
||||
Args:
|
||||
name: The name of a wrapper to load
|
||||
wrapper_name: The name of a wrapper to load.
|
||||
|
||||
Returns:
|
||||
Wrapper
|
||||
The specified wrapper.
|
||||
|
||||
Raises:
|
||||
AttributeError: If the wrapper does not exist.
|
||||
DeprecatedWrapper: If the version is not the latest.
|
||||
"""
|
||||
if name in _wrapper_to_class:
|
||||
import_stmt = f"gymnasium.experimental.wrappers.{_wrapper_to_class[name]}"
|
||||
# Check if the requested wrapper is in the _wrapper_to_class dictionary
|
||||
if wrapper_name in _wrapper_to_class:
|
||||
import_stmt = (
|
||||
f"gymnasium.experimental.wrappers.{_wrapper_to_class[wrapper_name]}"
|
||||
)
|
||||
module = importlib.import_module(import_stmt)
|
||||
return getattr(module, name)
|
||||
# add helpful error message if version number has changed
|
||||
return getattr(module, wrapper_name)
|
||||
|
||||
# Define a regex pattern to match the integer suffix (version number) of the wrapper
|
||||
int_suffix_pattern = r"(\d+)$"
|
||||
version_match = re.search(int_suffix_pattern, wrapper_name)
|
||||
|
||||
# If a version number is found, extract it and the base wrapper name
|
||||
if version_match:
|
||||
version = int(version_match.group())
|
||||
base_name = wrapper_name[: -len(version_match.group())]
|
||||
else:
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
version = float("inf")
|
||||
base_name = wrapper_name[:-2]
|
||||
|
||||
# Filter the list of all wrappers to include only those with the same base name
|
||||
matching_wrappers = [name for name in __all__ if name.startswith(base_name)]
|
||||
|
||||
# If no matching wrappers are found, raise an AttributeError
|
||||
if not matching_wrappers:
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}")
|
||||
|
||||
# Find the latest version of the matching wrappers
|
||||
latest_wrapper = max(
|
||||
matching_wrappers, key=lambda s: int(re.findall(int_suffix_pattern, s)[0])
|
||||
)
|
||||
latest_version = int(re.findall(int_suffix_pattern, latest_wrapper)[0])
|
||||
|
||||
# If the requested wrapper is an older version, raise a DeprecatedWrapper exception
|
||||
if version < latest_version:
|
||||
raise DeprecatedWrapper(
|
||||
f"{wrapper_name!r} is now deprecated, use {latest_wrapper!r} instead.\n"
|
||||
f"To see the changes made, go to "
|
||||
f"https://gymnasium.farama.org/api/experimental/wrappers/#gymnasium.experimental.wrappers.{latest_wrapper}"
|
||||
)
|
||||
# If the requested version is invalid, raise an AttributeError
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"module {__name__!r} has no attribute {wrapper_name!r}, did you mean {latest_wrapper!r}"
|
||||
)
|
||||
|
@@ -100,7 +100,7 @@ class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructor
|
||||
)
|
||||
|
||||
|
||||
class NormalizeRewardV0(
|
||||
class NormalizeRewardV1(
|
||||
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
|
||||
):
|
||||
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
||||
@@ -111,6 +111,10 @@ class NormalizeRewardV0(
|
||||
statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.normalize()` is called.
|
||||
If False, the calculated statistics are used but not updated anymore; this may be used during evaluation.
|
||||
|
||||
Note:
|
||||
In v0.27, NormalizeReward was updated as the forward discounted reward estimate was incorrectly computed in Gym v0.25+.
|
||||
For more detail, read [#3154](https://github.com/openai/gym/pull/3152).
|
||||
|
||||
Note:
|
||||
The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
|
||||
instantiated or the policy was changed recently.
|
||||
|
43
tests/experimental/wrappers/test_import_wrappers.py
Normal file
43
tests/experimental/wrappers/test_import_wrappers.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Test suite for import wrappers."""
|
||||
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
import gymnasium.experimental.wrappers as wrappers
|
||||
|
||||
|
||||
def test_import_wrappers():
|
||||
"""Test that all wrappers can be imported."""
|
||||
# Test that a deprecated wrapper raises a DeprecatedWrapper
|
||||
with pytest.raises(
|
||||
wrappers.DeprecatedWrapper,
|
||||
match=re.escape("'NormalizeRewardV0' is now deprecated"),
|
||||
):
|
||||
getattr(wrappers, "NormalizeRewardV0")
|
||||
|
||||
# Test that an invalid version raises an AttributeError
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match=re.escape(
|
||||
"module 'gymnasium.experimental.wrappers' has no attribute 'ClipRewardVT', did you mean"
|
||||
),
|
||||
):
|
||||
getattr(wrappers, "ClipRewardVT")
|
||||
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match=re.escape(
|
||||
"module 'gymnasium.experimental.wrappers' has no attribute 'ClipRewardV99', did you mean"
|
||||
),
|
||||
):
|
||||
getattr(wrappers, "ClipRewardV99")
|
||||
|
||||
# Test that an invalid wrapper raises an AttributeError
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match=re.escape(
|
||||
"module 'gymnasium.experimental.wrappers' has no attribute 'NonexistentWrapper'"
|
||||
),
|
||||
):
|
||||
getattr(wrappers, "NonexistentWrapper")
|
@@ -1,8 +1,8 @@
|
||||
"""Test suite for NormalizeRewardV0."""
|
||||
"""Test suite for NormalizeRewardV1."""
|
||||
import numpy as np
|
||||
|
||||
from gymnasium.core import ActType
|
||||
from gymnasium.experimental.wrappers import NormalizeRewardV0
|
||||
from gymnasium.experimental.wrappers import NormalizeRewardV1
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ def _make_reward_env():
|
||||
def test_running_mean_normalize_reward_wrapper():
|
||||
"""Tests that the property `_update_running_mean` freezes/continues the running statistics updating."""
|
||||
env = _make_reward_env()
|
||||
wrapped_env = NormalizeRewardV0(env)
|
||||
wrapped_env = NormalizeRewardV1(env)
|
||||
|
||||
# Default value is True
|
||||
assert wrapped_env.update_running_mean
|
||||
@@ -48,7 +48,7 @@ def test_normalize_reward_wrapper():
|
||||
"""Tests that the NormalizeReward does not throw an error."""
|
||||
# TODO: Functional correctness should be tested
|
||||
env = _make_reward_env()
|
||||
wrapped_env = NormalizeRewardV0(env)
|
||||
wrapped_env = NormalizeRewardV1(env)
|
||||
wrapped_env.reset()
|
||||
_, reward, _, _, _ = wrapped_env.step(None)
|
||||
assert np.ndim(reward) == 0
|
||||
|
Reference in New Issue
Block a user