Add deprecated wrapper error in gymnasium.experimental.wrappers (#341)

This commit is contained in:
Valentin
2023-04-16 13:04:55 +02:00
committed by GitHub
parent 3acaaeb7db
commit 30e846aec5
8 changed files with 117 additions and 21 deletions

View File

@@ -18,7 +18,7 @@ jobs:
--tag gymnasium-all-docker . --tag gymnasium-all-docker .
- name: Run tests - name: Run tests
run: docker run gymnasium-all-docker pytest tests/* run: docker run gymnasium-all-docker pytest tests/*
- name: Run doctest - name: Run doctests
run: docker run gymnasium-all-docker pytest --doctest-modules gymnasium/ run: docker run gymnasium-all-docker pytest --doctest-modules gymnasium/
build-necessary: build-necessary:

View File

@@ -101,7 +101,7 @@ We aimed to replace the wrappers in gymnasium v0.30.0 with these experimental wr
* - `supersuit.clip_reward_v0 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/generic_wrappers/basic_wrappers.py#L74>`_ * - `supersuit.clip_reward_v0 <https://github.com/Farama-Foundation/SuperSuit/blob/314831a7d18e7254f455b181694c049908f95155/supersuit/generic_wrappers/basic_wrappers.py#L74>`_
- :class:`experimental.wrappers.ClipRewardV0` - :class:`experimental.wrappers.ClipRewardV0`
* - :class:`wrappers.NormalizeReward` * - :class:`wrappers.NormalizeReward`
- :class:`experimental.wrappers.NormalizeRewardV0` - :class:`experimental.wrappers.NormalizeRewardV1`
``` ```
### Common Wrappers ### Common Wrappers

View File

@@ -37,7 +37,7 @@ title: Wrappers
```{eval-rst} ```{eval-rst}
.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0 .. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0
.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0 .. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0
.. autoclass:: gymnasium.experimental.wrappers.NormalizeRewardV0 .. autoclass:: gymnasium.experimental.wrappers.NormalizeRewardV1
``` ```
## Other Wrappers ## Other Wrappers

View File

@@ -181,6 +181,10 @@ class RetriesExceededError(Error):
"""Error message for retries exceeding set number.""" """Error message for retries exceeding set number."""
class DeprecatedWrapper(ImportError):
    """Error message for importing an old version of a wrapper.

    Raised by ``gymnasium.experimental.wrappers.__getattr__`` when the requested
    wrapper name carries a version suffix lower than the latest version listed
    in that module's ``__all__``; the message names the wrapper to use instead.
    Subclasses ``ImportError``, so handlers catching ``ImportError`` also catch it.
    """
# Vectorized environments errors # Vectorized environments errors

View File

@@ -1,7 +1,9 @@
"""`__init__` for experimental wrappers, to avoid loading the wrappers if unnecessary, we can hack python.""" """`__init__` for experimental wrappers, to avoid loading the wrappers if unnecessary, we can hack python."""
# pyright: reportUnsupportedDunderAll=false # pyright: reportUnsupportedDunderAll=false
import importlib import importlib
import re
from gymnasium.error import DeprecatedWrapper
__all__ = [ __all__ = [
@@ -30,7 +32,7 @@ __all__ = [
# --- Reward wrappers --- # --- Reward wrappers ---
"LambdaRewardV0", "LambdaRewardV0",
"ClipRewardV0", "ClipRewardV0",
"NormalizeRewardV0", "NormalizeRewardV1",
# --- Common --- # --- Common ---
"AutoresetV0", "AutoresetV0",
"PassiveEnvCheckerV0", "PassiveEnvCheckerV0",
@@ -66,7 +68,7 @@ _wrapper_to_class = {
# lambda_reward.py # lambda_reward.py
"ClipRewardV0": "lambda_reward", "ClipRewardV0": "lambda_reward",
"LambdaRewardV0": "lambda_reward", "LambdaRewardV0": "lambda_reward",
"NormalizeRewardV0": "lambda_reward", "NormalizeRewardV1": "lambda_reward",
# stateful_action # stateful_action
"StickyActionV0": "stateful_action", "StickyActionV0": "stateful_action",
# stateful_observation # stateful_observation
@@ -99,21 +101,64 @@ _wrapper_to_class = {
} }
def __getattr__(name: str): def __getattr__(wrapper_name: str):
"""To avoid having to load all wrappers on `import gymnasium` with all of their extra modules. """Load a wrapper by name.
This optimises the loading of gymnasium. This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used.
Errors will be raised if the wrapper does not exist or if the version is not the latest.
Args: Args:
name: The name of a wrapper to load wrapper_name: The name of a wrapper to load.
Returns: Returns:
Wrapper The specified wrapper.
Raises:
AttributeError: If the wrapper does not exist.
DeprecatedWrapper: If the version is not the latest.
""" """
if name in _wrapper_to_class: # Check if the requested wrapper is in the _wrapper_to_class dictionary
import_stmt = f"gymnasium.experimental.wrappers.{_wrapper_to_class[name]}" if wrapper_name in _wrapper_to_class:
import_stmt = (
f"gymnasium.experimental.wrappers.{_wrapper_to_class[wrapper_name]}"
)
module = importlib.import_module(import_stmt) module = importlib.import_module(import_stmt)
return getattr(module, name) return getattr(module, wrapper_name)
# add helpful error message if version number has changed
# Define a regex pattern to match the integer suffix (version number) of the wrapper
int_suffix_pattern = r"(\d+)$"
version_match = re.search(int_suffix_pattern, wrapper_name)
# If a version number is found, extract it and the base wrapper name
if version_match:
version = int(version_match.group())
base_name = wrapper_name[: -len(version_match.group())]
else: else:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}") version = float("inf")
base_name = wrapper_name[:-2]
# Filter the list of all wrappers to include only those with the same base name
matching_wrappers = [name for name in __all__ if name.startswith(base_name)]
# If no matching wrappers are found, raise an AttributeError
if not matching_wrappers:
raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}")
# Find the latest version of the matching wrappers
latest_wrapper = max(
matching_wrappers, key=lambda s: int(re.findall(int_suffix_pattern, s)[0])
)
latest_version = int(re.findall(int_suffix_pattern, latest_wrapper)[0])
# If the requested wrapper is an older version, raise a DeprecatedWrapper exception
if version < latest_version:
raise DeprecatedWrapper(
f"{wrapper_name!r} is now deprecated, use {latest_wrapper!r} instead.\n"
f"To see the changes made, go to "
f"https://gymnasium.farama.org/api/experimental/wrappers/#gymnasium.experimental.wrappers.{latest_wrapper}"
)
# If the requested version is invalid, raise an AttributeError
else:
raise AttributeError(
f"module {__name__!r} has no attribute {wrapper_name!r}, did you mean {latest_wrapper!r}"
)

View File

@@ -100,7 +100,7 @@ class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructor
) )
class NormalizeRewardV0( class NormalizeRewardV1(
gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
): ):
r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
@@ -111,6 +111,10 @@ class NormalizeRewardV0(
statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.normalize()` is called. statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.normalize()` is called.
If False, the calculated statistics are used but not updated anymore; this may be used during evaluation. If False, the calculated statistics are used but not updated anymore; this may be used during evaluation.
Note:
In v0.27, NormalizeReward was updated as the forward discounted reward estimate was incorrectly computed in Gym v0.25+.
For more detail, read [#3152](https://github.com/openai/gym/pull/3152).
Note: Note:
The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly
instantiated or the policy was changed recently. instantiated or the policy was changed recently.

View File

@@ -0,0 +1,43 @@
"""Test suite for import wrappers."""
import re
import pytest
import gymnasium.experimental.wrappers as wrappers
def test_import_wrappers():
    """Check the errors raised when an unavailable wrapper is looked up.

    Each case requests an attribute from ``gymnasium.experimental.wrappers``
    and expects a specific exception type whose message starts with a given
    text: a superseded version, two invalid version suffixes and an unknown
    wrapper name.
    """
    # (expected exception, requested attribute, expected start of the message)
    failing_lookups = (
        # A deprecated wrapper raises a DeprecatedWrapper
        (
            wrappers.DeprecatedWrapper,
            "NormalizeRewardV0",
            "'NormalizeRewardV0' is now deprecated",
        ),
        # An invalid version raises an AttributeError with a suggestion
        (
            AttributeError,
            "ClipRewardVT",
            "module 'gymnasium.experimental.wrappers' has no attribute 'ClipRewardVT', did you mean",
        ),
        (
            AttributeError,
            "ClipRewardV99",
            "module 'gymnasium.experimental.wrappers' has no attribute 'ClipRewardV99', did you mean",
        ),
        # An invalid wrapper raises a plain AttributeError
        (
            AttributeError,
            "NonexistentWrapper",
            "module 'gymnasium.experimental.wrappers' has no attribute 'NonexistentWrapper'",
        ),
    )

    for expected_error, attribute, message_start in failing_lookups:
        # re.escape keeps quotes and dots in the message from acting as regex syntax
        with pytest.raises(expected_error, match=re.escape(message_start)):
            getattr(wrappers, attribute)

View File

@@ -1,8 +1,8 @@
"""Test suite for NormalizeRewardV0.""" """Test suite for NormalizeRewardV1."""
import numpy as np import numpy as np
from gymnasium.core import ActType from gymnasium.core import ActType
from gymnasium.experimental.wrappers import NormalizeRewardV0 from gymnasium.experimental.wrappers import NormalizeRewardV1
from tests.testing_env import GenericTestEnv from tests.testing_env import GenericTestEnv
@@ -18,7 +18,7 @@ def _make_reward_env():
def test_running_mean_normalize_reward_wrapper(): def test_running_mean_normalize_reward_wrapper():
"""Tests that the property `_update_running_mean` freezes/continues the running statistics updating.""" """Tests that the property `_update_running_mean` freezes/continues the running statistics updating."""
env = _make_reward_env() env = _make_reward_env()
wrapped_env = NormalizeRewardV0(env) wrapped_env = NormalizeRewardV1(env)
# Default value is True # Default value is True
assert wrapped_env.update_running_mean assert wrapped_env.update_running_mean
@@ -48,7 +48,7 @@ def test_normalize_reward_wrapper():
"""Tests that the NormalizeReward does not throw an error.""" """Tests that the NormalizeReward does not throw an error."""
# TODO: Functional correctness should be tested # TODO: Functional correctness should be tested
env = _make_reward_env() env = _make_reward_env()
wrapped_env = NormalizeRewardV0(env) wrapped_env = NormalizeRewardV1(env)
wrapped_env.reset() wrapped_env.reset()
_, reward, _, _, _ = wrapped_env.step(None) _, reward, _, _, _ = wrapped_env.step(None)
assert np.ndim(reward) == 0 assert np.ndim(reward) == 0