mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-26 16:27:11 +00:00
Support only new step API (while retaining compatibility functions) (#3019)
This commit is contained in:
@@ -31,9 +31,9 @@ observation, info = env.reset(seed=42)
|
|||||||
|
|
||||||
for _ in range(1000):
|
for _ in range(1000):
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
observation, reward, done, info = env.step(action)
|
observation, reward, terminated, truncated, info = env.step(action)
|
||||||
|
|
||||||
if done:
|
if terminated or truncated:
|
||||||
observation, info = env.reset()
|
observation, info = env.reset()
|
||||||
env.close()
|
env.close()
|
||||||
```
|
```
|
||||||
|
46
gym/core.py
46
gym/core.py
@@ -16,7 +16,7 @@ from typing import (
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
from gym.logger import deprecation, warn
|
from gym.logger import warn
|
||||||
from gym.utils import seeding
|
from gym.utils import seeding
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -83,16 +83,11 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
def np_random(self, value: np.random.Generator):
|
def np_random(self, value: np.random.Generator):
|
||||||
self._np_random = value
|
self._np_random = value
|
||||||
|
|
||||||
def step(
|
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
|
||||||
self, action: ActType
|
|
||||||
) -> Union[
|
|
||||||
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict]
|
|
||||||
]:
|
|
||||||
"""Run one timestep of the environment's dynamics.
|
"""Run one timestep of the environment's dynamics.
|
||||||
|
|
||||||
When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state.
|
When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state.
|
||||||
Accepts an action and returns either a tuple `(observation, reward, terminated, truncated, info)`, or a tuple
|
Accepts an action and returns a tuple `(observation, reward, terminated, truncated, info)`.
|
||||||
(observation, reward, done, info). The latter is deprecated and will be removed in future versions.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
action (ActType): an action provided by the agent
|
action (ActType): an action provided by the agent
|
||||||
@@ -226,12 +221,11 @@ class Wrapper(Env[ObsType, ActType]):
|
|||||||
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
|
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: Env, new_step_api: bool = False):
|
def __init__(self, env: Env):
|
||||||
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
|
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env: The environment to wrap
|
env: The environment to wrap
|
||||||
new_step_api: Whether the wrapper's step method will output in new or old step API
|
|
||||||
"""
|
"""
|
||||||
self.env = env
|
self.env = env
|
||||||
|
|
||||||
@@ -239,12 +233,6 @@ class Wrapper(Env[ObsType, ActType]):
|
|||||||
self._observation_space: Optional[spaces.Space] = None
|
self._observation_space: Optional[spaces.Space] = None
|
||||||
self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None
|
self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None
|
||||||
self._metadata: Optional[dict] = None
|
self._metadata: Optional[dict] = None
|
||||||
self.new_step_api = new_step_api
|
|
||||||
|
|
||||||
if not self.new_step_api:
|
|
||||||
deprecation(
|
|
||||||
"Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
|
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
|
||||||
@@ -326,17 +314,9 @@ class Wrapper(Env[ObsType, ActType]):
|
|||||||
"Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
|
"Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
|
||||||
)
|
)
|
||||||
|
|
||||||
def step(
|
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
|
||||||
self, action: ActType
|
|
||||||
) -> Union[
|
|
||||||
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict]
|
|
||||||
]:
|
|
||||||
"""Steps through the environment with action."""
|
"""Steps through the environment with action."""
|
||||||
from gym.utils.step_api_compatibility import ( # avoid circular import
|
return self.env.step(action)
|
||||||
step_api_compatibility,
|
|
||||||
)
|
|
||||||
|
|
||||||
return step_api_compatibility(self.env.step(action), self.new_step_api)
|
|
||||||
|
|
||||||
def reset(self, **kwargs) -> Tuple[ObsType, dict]:
|
def reset(self, **kwargs) -> Tuple[ObsType, dict]:
|
||||||
"""Resets the environment with kwargs."""
|
"""Resets the environment with kwargs."""
|
||||||
@@ -401,13 +381,8 @@ class ObservationWrapper(Wrapper):
|
|||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
|
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
|
||||||
step_returns = self.env.step(action)
|
observation, reward, terminated, truncated, info = self.env.step(action)
|
||||||
if len(step_returns) == 5:
|
|
||||||
observation, reward, terminated, truncated, info = step_returns
|
|
||||||
return self.observation(observation), reward, terminated, truncated, info
|
return self.observation(observation), reward, terminated, truncated, info
|
||||||
else:
|
|
||||||
observation, reward, done, info = step_returns
|
|
||||||
return self.observation(observation), reward, done, info
|
|
||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
"""Returns a modified observation."""
|
"""Returns a modified observation."""
|
||||||
@@ -440,13 +415,8 @@ class RewardWrapper(Wrapper):
|
|||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`."""
|
"""Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`."""
|
||||||
step_returns = self.env.step(action)
|
observation, reward, terminated, truncated, info = self.env.step(action)
|
||||||
if len(step_returns) == 5:
|
|
||||||
observation, reward, terminated, truncated, info = step_returns
|
|
||||||
return observation, self.reward(reward), terminated, truncated, info
|
return observation, self.reward(reward), terminated, truncated, info
|
||||||
else:
|
|
||||||
observation, reward, done, info = step_returns
|
|
||||||
return observation, self.reward(reward), done, info
|
|
||||||
|
|
||||||
def reward(self, reward):
|
def reward(self, reward):
|
||||||
"""Returns a modified ``reward``."""
|
"""Returns a modified ``reward``."""
|
||||||
|
@@ -140,7 +140,7 @@ class EnvSpec:
|
|||||||
order_enforce: bool = field(default=True)
|
order_enforce: bool = field(default=True)
|
||||||
autoreset: bool = field(default=False)
|
autoreset: bool = field(default=False)
|
||||||
disable_env_checker: bool = field(default=False)
|
disable_env_checker: bool = field(default=False)
|
||||||
new_step_api: bool = field(default=False)
|
apply_step_compatibility: bool = field(default=False)
|
||||||
|
|
||||||
# Environment arguments
|
# Environment arguments
|
||||||
kwargs: dict = field(default_factory=dict)
|
kwargs: dict = field(default_factory=dict)
|
||||||
@@ -547,7 +547,7 @@ def make(
|
|||||||
id: Union[str, EnvSpec],
|
id: Union[str, EnvSpec],
|
||||||
max_episode_steps: Optional[int] = None,
|
max_episode_steps: Optional[int] = None,
|
||||||
autoreset: bool = False,
|
autoreset: bool = False,
|
||||||
new_step_api: bool = False,
|
apply_step_compatibility: bool = False,
|
||||||
disable_env_checker: Optional[bool] = None,
|
disable_env_checker: Optional[bool] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Env:
|
) -> Env:
|
||||||
@@ -557,7 +557,7 @@ def make(
|
|||||||
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
|
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
|
||||||
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
|
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
|
||||||
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
|
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
|
||||||
new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper). Will be removed at v1.0
|
apply_step_compatibility: Whether to apply the compatibility wrapper that converts the step method to return two bools (StepAPICompatibility wrapper)
|
||||||
disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker`
|
disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker`
|
||||||
(which is by default False, running the environment checker),
|
(which is by default False, running the environment checker),
|
||||||
otherwise will run according to this parameter (`True` = not run, `False` = run)
|
otherwise will run according to this parameter (`True` = not run, `False` = run)
|
||||||
@@ -684,26 +684,28 @@ def make(
|
|||||||
):
|
):
|
||||||
env = PassiveEnvChecker(env)
|
env = PassiveEnvChecker(env)
|
||||||
|
|
||||||
env = StepAPICompatibility(env, new_step_api)
|
|
||||||
|
|
||||||
# Add the order enforcing wrapper
|
# Add the order enforcing wrapper
|
||||||
if spec_.order_enforce:
|
if spec_.order_enforce:
|
||||||
env = OrderEnforcing(env)
|
env = OrderEnforcing(env)
|
||||||
|
|
||||||
# Add the time limit wrapper
|
# Add the time limit wrapper
|
||||||
if max_episode_steps is not None:
|
if max_episode_steps is not None:
|
||||||
env = TimeLimit(env, max_episode_steps, new_step_api)
|
env = TimeLimit(env, max_episode_steps)
|
||||||
elif spec_.max_episode_steps is not None:
|
elif spec_.max_episode_steps is not None:
|
||||||
env = TimeLimit(env, spec_.max_episode_steps, new_step_api)
|
env = TimeLimit(env, spec_.max_episode_steps)
|
||||||
|
|
||||||
# Add the autoreset wrapper
|
# Add the autoreset wrapper
|
||||||
if autoreset:
|
if autoreset:
|
||||||
env = AutoResetWrapper(env, new_step_api)
|
env = AutoResetWrapper(env)
|
||||||
|
|
||||||
# Add human rendering wrapper
|
# Add human rendering wrapper
|
||||||
if apply_human_rendering:
|
if apply_human_rendering:
|
||||||
env = HumanRendering(env)
|
env = HumanRendering(env)
|
||||||
|
|
||||||
|
# Add step API wrapper
|
||||||
|
if apply_step_compatibility:
|
||||||
|
env = StepAPICompatibility(env, True)
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
@@ -195,7 +195,11 @@ def env_reset_passive_checker(env, **kwargs):
|
|||||||
logger.warn(
|
logger.warn(
|
||||||
f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(result)}`"
|
f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(result)}`"
|
||||||
)
|
)
|
||||||
|
elif len(result) != 2:
|
||||||
|
logger.warn(
|
||||||
|
"The result returned by `env.reset()` should be `(obs, info)` by default, , where `obs` is a observation and `info` is a dictionary containing additional information."
|
||||||
|
)
|
||||||
|
else:
|
||||||
obs, info = result
|
obs, info = result
|
||||||
check_obs(obs, env.observation_space, "reset")
|
check_obs(obs, env.observation_space, "reset")
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
|
@@ -170,7 +170,7 @@ def play(
|
|||||||
:class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward
|
:class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward
|
||||||
for last 150 steps.
|
for last 150 steps.
|
||||||
|
|
||||||
>>> def callback(obs_t, obs_tp1, action, rew, done, info):
|
>>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
|
||||||
... return [rew,]
|
... return [rew,]
|
||||||
>>> plotter = PlayPlot(callback, 150, ["reward"])
|
>>> plotter = PlayPlot(callback, 150, ["reward"])
|
||||||
>>> play(gym.make("ALE/AirRaid-v5"), callback=plotter.callback)
|
>>> play(gym.make("ALE/AirRaid-v5"), callback=plotter.callback)
|
||||||
@@ -187,7 +187,8 @@ def play(
|
|||||||
obs_tp1: observation after performing action
|
obs_tp1: observation after performing action
|
||||||
action: action that was executed
|
action: action that was executed
|
||||||
rew: reward that was received
|
rew: reward that was received
|
||||||
done: whether the environment is done or not
|
terminated: whether the environment is terminated or not
|
||||||
|
truncated: whether the environment is truncated or not
|
||||||
info: debug info
|
info: debug info
|
||||||
keys_to_action: Mapping from keys pressed to action performed.
|
keys_to_action: Mapping from keys pressed to action performed.
|
||||||
Different formats are supported: Key combinations can either be expressed as a tuple of unicode code
|
Different formats are supported: Key combinations can either be expressed as a tuple of unicode code
|
||||||
@@ -219,11 +220,6 @@ def play(
|
|||||||
deprecation(
|
deprecation(
|
||||||
"`play.py` currently supports only the old step API which returns one boolean, however this will soon be updated to support only the new step api that returns two bools."
|
"`play.py` currently supports only the old step API which returns one boolean, however this will soon be updated to support only the new step api that returns two bools."
|
||||||
)
|
)
|
||||||
if env.render_mode not in {"rgb_array", "single_rgb_array"}:
|
|
||||||
logger.error(
|
|
||||||
"play method works only with rgb_array and single_rgb_array render modes, "
|
|
||||||
f"but your environment render_mode = {env.render_mode}."
|
|
||||||
)
|
|
||||||
|
|
||||||
env.reset(seed=seed)
|
env.reset(seed=seed)
|
||||||
|
|
||||||
@@ -261,9 +257,10 @@ def play(
|
|||||||
else:
|
else:
|
||||||
action = key_code_to_action.get(tuple(sorted(game.pressed_keys)), noop)
|
action = key_code_to_action.get(tuple(sorted(game.pressed_keys)), noop)
|
||||||
prev_obs = obs
|
prev_obs = obs
|
||||||
obs, rew, done, info = env.step(action)
|
obs, rew, terminated, truncated, info = env.step(action)
|
||||||
|
done = terminated or truncated
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback(prev_obs, obs, action, rew, done, info)
|
callback(prev_obs, obs, action, rew, terminated, truncated, info)
|
||||||
if obs is not None:
|
if obs is not None:
|
||||||
rendered = env.render()
|
rendered = env.render()
|
||||||
if isinstance(rendered, List):
|
if isinstance(rendered, List):
|
||||||
@@ -290,13 +287,14 @@ class PlayPlot:
|
|||||||
- obs_tp1: observation after performing action
|
- obs_tp1: observation after performing action
|
||||||
- action: action that was executed
|
- action: action that was executed
|
||||||
- rew: reward that was received
|
- rew: reward that was received
|
||||||
- done: whether the environment is done or not
|
- terminated: whether the environment is terminated or not
|
||||||
|
- truncated: whether the environment is truncated or not
|
||||||
- info: debug info
|
- info: debug info
|
||||||
|
|
||||||
It should return a list of metrics that are computed from this data.
|
It should return a list of metrics that are computed from this data.
|
||||||
For instance, the function may look like this::
|
For instance, the function may look like this::
|
||||||
|
|
||||||
>>> def compute_metrics(obs_t, obs_tp, action, reward, done, info):
|
>>> def compute_metrics(obs_t, obs_tp, action, reward, terminated, truncated, info):
|
||||||
... return [reward, info["cumulative_reward"], np.linalg.norm(action)]
|
... return [reward, info["cumulative_reward"], np.linalg.norm(action)]
|
||||||
|
|
||||||
:class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function
|
:class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function
|
||||||
@@ -353,7 +351,8 @@ class PlayPlot:
|
|||||||
obs_tp1: ObsType,
|
obs_tp1: ObsType,
|
||||||
action: ActType,
|
action: ActType,
|
||||||
rew: float,
|
rew: float,
|
||||||
done: bool,
|
terminated: bool,
|
||||||
|
truncated: bool,
|
||||||
info: dict,
|
info: dict,
|
||||||
):
|
):
|
||||||
"""The callback that calls the provided data callback and adds the data to the plots.
|
"""The callback that calls the provided data callback and adds the data to the plots.
|
||||||
@@ -363,10 +362,13 @@ class PlayPlot:
|
|||||||
obs_tp1: The observation at time step t+1
|
obs_tp1: The observation at time step t+1
|
||||||
action: The action
|
action: The action
|
||||||
rew: The reward
|
rew: The reward
|
||||||
done: If the environment is done
|
terminated: If the environment is terminated
|
||||||
|
truncated: If the environment is truncated
|
||||||
info: The information from the environment
|
info: The information from the environment
|
||||||
"""
|
"""
|
||||||
points = self.data_callback(obs_t, obs_tp1, action, rew, done, info)
|
points = self.data_callback(
|
||||||
|
obs_t, obs_tp1, action, rew, terminated, truncated, info
|
||||||
|
)
|
||||||
for point, data_series in zip(points, self.data):
|
for point, data_series in zip(points, self.data):
|
||||||
data_series.append(point)
|
data_series.append(point)
|
||||||
self.t += 1
|
self.t += 1
|
||||||
|
@@ -1,18 +1,18 @@
|
|||||||
"""Contains methods for step compatibility, from old-to-new and new-to-old API, to be removed in 1.0."""
|
"""Contains methods for step compatibility, from old-to-new and new-to-old API."""
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gym.core import ObsType
|
from gym.core import ObsType
|
||||||
|
|
||||||
OldStepType = Tuple[
|
DoneStepType = Tuple[
|
||||||
Union[ObsType, np.ndarray],
|
Union[ObsType, np.ndarray],
|
||||||
Union[float, np.ndarray],
|
Union[float, np.ndarray],
|
||||||
Union[bool, np.ndarray],
|
Union[bool, np.ndarray],
|
||||||
Union[dict, list],
|
Union[dict, list],
|
||||||
]
|
]
|
||||||
|
|
||||||
NewStepType = Tuple[
|
TerminatedTruncatedStepType = Tuple[
|
||||||
Union[ObsType, np.ndarray],
|
Union[ObsType, np.ndarray],
|
||||||
Union[float, np.ndarray],
|
Union[float, np.ndarray],
|
||||||
Union[bool, np.ndarray],
|
Union[bool, np.ndarray],
|
||||||
@@ -21,9 +21,9 @@ NewStepType = Tuple[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def step_to_new_api(
|
def convert_to_terminated_truncated_step_api(
|
||||||
step_returns: Union[OldStepType, NewStepType], is_vector_env=False
|
step_returns: Union[DoneStepType, TerminatedTruncatedStepType], is_vector_env=False
|
||||||
) -> NewStepType:
|
) -> TerminatedTruncatedStepType:
|
||||||
"""Function to transform step returns to new step API irrespective of input API.
|
"""Function to transform step returns to new step API irrespective of input API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -73,9 +73,10 @@ def step_to_new_api(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def step_to_old_api(
|
def convert_to_done_step_api(
|
||||||
step_returns: Union[NewStepType, OldStepType], is_vector_env: bool = False
|
step_returns: Union[TerminatedTruncatedStepType, DoneStepType],
|
||||||
) -> OldStepType:
|
is_vector_env: bool = False,
|
||||||
|
) -> DoneStepType:
|
||||||
"""Function to transform step returns to old step API irrespective of input API.
|
"""Function to transform step returns to old step API irrespective of input API.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -128,33 +129,33 @@ def step_to_old_api(
|
|||||||
|
|
||||||
|
|
||||||
def step_api_compatibility(
|
def step_api_compatibility(
|
||||||
step_returns: Union[NewStepType, OldStepType],
|
step_returns: Union[TerminatedTruncatedStepType, DoneStepType],
|
||||||
new_step_api: bool = False,
|
output_truncation_bool: bool = True,
|
||||||
is_vector_env: bool = False,
|
is_vector_env: bool = False,
|
||||||
) -> Union[NewStepType, OldStepType]:
|
) -> Union[TerminatedTruncatedStepType, DoneStepType]:
|
||||||
"""Function to transform step returns to the API specified by `new_step_api` bool.
|
"""Function to transform step returns to the API specified by `output_truncation_bool` bool.
|
||||||
|
|
||||||
Old step API refers to step() method returning (observation, reward, done, info)
|
Done (old) step API refers to step() method returning (observation, reward, done, info)
|
||||||
New step API refers to step() method returning (observation, reward, terminated, truncated, info)
|
Terminated Truncated (new) step API refers to step() method returning (observation, reward, terminated, truncated, info)
|
||||||
(Refer to docs for details on the API change)
|
(Refer to docs for details on the API change)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
|
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
|
||||||
new_step_api (bool): Whether the output should be in new step API or old (False by default)
|
output_truncation_bool (bool): Whether the output should return two booleans (new API) or one (old) (True by default)
|
||||||
is_vector_env (bool): Whether the step_returns are from a vector environment
|
is_vector_env (bool): Whether the step_returns are from a vector environment
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
step_returns (tuple): Depending on `new_step_api` bool, it can return (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
|
step_returns (tuple): Depending on `output_truncation_bool` bool, it can return (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API,
|
This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API,
|
||||||
wrapper is written in new API, and the final step output is desired to be in old API.
|
wrapper is written in new API, and the final step output is desired to be in old API.
|
||||||
|
|
||||||
>>> obs, rew, done, info = step_api_compatibility(env.step(action))
|
>>> obs, rew, done, info = step_api_compatibility(env.step(action), output_truncation_bool=False)
|
||||||
>>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), new_step_api=True)
|
>>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), output_truncation_bool=True)
|
||||||
>>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True)
|
>>> observations, rewards, terminateds, truncateds, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True)
|
||||||
"""
|
"""
|
||||||
if new_step_api:
|
if output_truncation_bool:
|
||||||
return step_to_new_api(step_returns, is_vector_env)
|
return convert_to_terminated_truncated_step_api(step_returns, is_vector_env)
|
||||||
else:
|
else:
|
||||||
return step_to_old_api(step_returns, is_vector_env)
|
return convert_to_done_step_api(step_returns, is_vector_env)
|
||||||
|
@@ -15,7 +15,6 @@ def make(
|
|||||||
asynchronous: bool = True,
|
asynchronous: bool = True,
|
||||||
wrappers: Optional[Union[callable, List[callable]]] = None,
|
wrappers: Optional[Union[callable, List[callable]]] = None,
|
||||||
disable_env_checker: Optional[bool] = None,
|
disable_env_checker: Optional[bool] = None,
|
||||||
new_step_api: bool = False,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> VectorEnv:
|
) -> VectorEnv:
|
||||||
"""Create a vectorized environment from multiple copies of an environment, from its id.
|
"""Create a vectorized environment from multiple copies of an environment, from its id.
|
||||||
@@ -37,7 +36,6 @@ def make(
|
|||||||
wrappers: If not ``None``, then apply the wrappers to each internal environment during creation.
|
wrappers: If not ``None``, then apply the wrappers to each internal environment during creation.
|
||||||
disable_env_checker: If to run the env checker for the first environment only. None will default to the environment spec `disable_env_checker` parameter
|
disable_env_checker: If to run the env checker for the first environment only. None will default to the environment spec `disable_env_checker` parameter
|
||||||
(that is by default False), otherwise will run according to this argument (True = not run, False = run)
|
(that is by default False), otherwise will run according to this argument (True = not run, False = run)
|
||||||
new_step_api: If True, the vector environment's step method outputs two booleans `terminated`, `truncated` instead of one `done`.
|
|
||||||
**kwargs: Keywords arguments applied during `gym.make`
|
**kwargs: Keywords arguments applied during `gym.make`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -53,7 +51,6 @@ def make(
|
|||||||
env = gym.envs.registration.make(
|
env = gym.envs.registration.make(
|
||||||
id,
|
id,
|
||||||
disable_env_checker=_disable_env_checker,
|
disable_env_checker=_disable_env_checker,
|
||||||
new_step_api=True,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
if wrappers is not None:
|
if wrappers is not None:
|
||||||
@@ -73,8 +70,4 @@ def make(
|
|||||||
env_fns = [
|
env_fns = [
|
||||||
create_env(disable_env_checker or env_num > 0) for env_num in range(num_envs)
|
create_env(disable_env_checker or env_num > 0) for env_num in range(num_envs)
|
||||||
]
|
]
|
||||||
return (
|
return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns)
|
||||||
AsyncVectorEnv(env_fns, new_step_api=new_step_api)
|
|
||||||
if asynchronous
|
|
||||||
else SyncVectorEnv(env_fns, new_step_api=new_step_api)
|
|
||||||
)
|
|
||||||
|
@@ -17,7 +17,6 @@ from gym.error import (
|
|||||||
CustomSpaceError,
|
CustomSpaceError,
|
||||||
NoAsyncCallError,
|
NoAsyncCallError,
|
||||||
)
|
)
|
||||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
|
||||||
from gym.vector.utils import (
|
from gym.vector.utils import (
|
||||||
CloudpickleWrapper,
|
CloudpickleWrapper,
|
||||||
clear_mpi_env_vars,
|
clear_mpi_env_vars,
|
||||||
@@ -67,7 +66,6 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
daemon: bool = True,
|
daemon: bool = True,
|
||||||
worker: Optional[callable] = None,
|
worker: Optional[callable] = None,
|
||||||
new_step_api: bool = False,
|
|
||||||
):
|
):
|
||||||
"""Vectorized environment that runs multiple environments in parallel.
|
"""Vectorized environment that runs multiple environments in parallel.
|
||||||
|
|
||||||
@@ -87,7 +85,6 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
so for some environments you may want to have it set to ``False``.
|
so for some environments you may want to have it set to ``False``.
|
||||||
worker: If set, then use that worker in a subprocess instead of a default one.
|
worker: If set, then use that worker in a subprocess instead of a default one.
|
||||||
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
|
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
|
||||||
new_step_api: If True, step method returns 2 bools - terminated, truncated, instead of 1 bool - done
|
|
||||||
|
|
||||||
Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance
|
Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance
|
||||||
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
|
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
|
||||||
@@ -115,7 +112,6 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
num_envs=len(env_fns),
|
num_envs=len(env_fns),
|
||||||
observation_space=observation_space,
|
observation_space=observation_space,
|
||||||
action_space=action_space,
|
action_space=action_space,
|
||||||
new_step_api=new_step_api,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.shared_memory:
|
if self.shared_memory:
|
||||||
@@ -291,14 +287,14 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
def step_wait(
|
def step_wait(
|
||||||
self, timeout: Optional[Union[int, float]] = None
|
self, timeout: Optional[Union[int, float]] = None
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[dict]]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[dict]]:
|
||||||
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
|
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.
|
timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The batched environment step information, (obs, reward, terminated, truncated, info) or (obs, reward, done, info) depending on new_step_api
|
The batched environment step information, (obs, reward, terminated, truncated, info)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
|
||||||
@@ -322,7 +318,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
successes = []
|
successes = []
|
||||||
for i, pipe in enumerate(self.parent_pipes):
|
for i, pipe in enumerate(self.parent_pipes):
|
||||||
result, success = pipe.recv()
|
result, success = pipe.recv()
|
||||||
obs, rew, terminated, truncated, info = step_api_compatibility(result, True)
|
obs, rew, terminated, truncated, info = result
|
||||||
|
|
||||||
successes.append(success)
|
successes.append(success)
|
||||||
observations_list.append(obs)
|
observations_list.append(obs)
|
||||||
@@ -341,16 +337,12 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
self.observations,
|
self.observations,
|
||||||
)
|
)
|
||||||
|
|
||||||
return step_api_compatibility(
|
return (
|
||||||
(
|
|
||||||
deepcopy(self.observations) if self.copy else self.observations,
|
deepcopy(self.observations) if self.copy else self.observations,
|
||||||
np.array(rewards),
|
np.array(rewards),
|
||||||
np.array(terminateds, dtype=np.bool_),
|
np.array(terminateds, dtype=np.bool_),
|
||||||
np.array(truncateds, dtype=np.bool_),
|
np.array(truncateds, dtype=np.bool_),
|
||||||
infos,
|
infos,
|
||||||
),
|
|
||||||
self.new_step_api,
|
|
||||||
True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def call_async(self, name: str, *args, **kwargs):
|
def call_async(self, name: str, *args, **kwargs):
|
||||||
@@ -572,7 +564,7 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
|||||||
terminated,
|
terminated,
|
||||||
truncated,
|
truncated,
|
||||||
info,
|
info,
|
||||||
) = step_api_compatibility(env.step(data), True)
|
) = env.step(data)
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
old_observation = observation
|
old_observation = observation
|
||||||
observation, info = env.reset()
|
observation, info = env.reset()
|
||||||
@@ -642,7 +634,7 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
|
|||||||
terminated,
|
terminated,
|
||||||
truncated,
|
truncated,
|
||||||
info,
|
info,
|
||||||
) = step_api_compatibility(env.step(data), True)
|
) = env.step(data)
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
old_observation = observation
|
old_observation = observation
|
||||||
observation, info = env.reset()
|
observation, info = env.reset()
|
||||||
|
@@ -6,7 +6,6 @@ import numpy as np
|
|||||||
|
|
||||||
from gym import Env
|
from gym import Env
|
||||||
from gym.spaces import Space
|
from gym.spaces import Space
|
||||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
|
||||||
from gym.vector.utils import concatenate, create_empty_array, iterate
|
from gym.vector.utils import concatenate, create_empty_array, iterate
|
||||||
from gym.vector.vector_env import VectorEnv
|
from gym.vector.vector_env import VectorEnv
|
||||||
|
|
||||||
@@ -34,7 +33,6 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
observation_space: Space = None,
|
observation_space: Space = None,
|
||||||
action_space: Space = None,
|
action_space: Space = None,
|
||||||
copy: bool = True,
|
copy: bool = True,
|
||||||
new_step_api: bool = False,
|
|
||||||
):
|
):
|
||||||
"""Vectorized environment that serially runs multiple environments.
|
"""Vectorized environment that serially runs multiple environments.
|
||||||
|
|
||||||
@@ -62,7 +60,6 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
num_envs=len(self.envs),
|
num_envs=len(self.envs),
|
||||||
observation_space=observation_space,
|
observation_space=observation_space,
|
||||||
action_space=action_space,
|
action_space=action_space,
|
||||||
new_step_api=new_step_api,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._check_spaces()
|
self._check_spaces()
|
||||||
@@ -143,13 +140,15 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
"""
|
"""
|
||||||
observations, infos = [], {}
|
observations, infos = [], {}
|
||||||
for i, (env, action) in enumerate(zip(self.envs, self._actions)):
|
for i, (env, action) in enumerate(zip(self.envs, self._actions)):
|
||||||
|
|
||||||
(
|
(
|
||||||
observation,
|
observation,
|
||||||
self._rewards[i],
|
self._rewards[i],
|
||||||
self._terminateds[i],
|
self._terminateds[i],
|
||||||
self._truncateds[i],
|
self._truncateds[i],
|
||||||
info,
|
info,
|
||||||
) = step_api_compatibility(env.step(action), True)
|
) = env.step(action)
|
||||||
|
|
||||||
if self._terminateds[i] or self._truncateds[i]:
|
if self._terminateds[i] or self._truncateds[i]:
|
||||||
old_observation = observation
|
old_observation = observation
|
||||||
observation, info = env.reset()
|
observation, info = env.reset()
|
||||||
@@ -160,16 +159,12 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
self.single_observation_space, observations, self.observations
|
self.single_observation_space, observations, self.observations
|
||||||
)
|
)
|
||||||
|
|
||||||
return step_api_compatibility(
|
return (
|
||||||
(
|
|
||||||
deepcopy(self.observations) if self.copy else self.observations,
|
deepcopy(self.observations) if self.copy else self.observations,
|
||||||
np.copy(self._rewards),
|
np.copy(self._rewards),
|
||||||
np.copy(self._terminateds),
|
np.copy(self._terminateds),
|
||||||
np.copy(self._truncateds),
|
np.copy(self._truncateds),
|
||||||
infos,
|
infos,
|
||||||
),
|
|
||||||
new_step_api=self.new_step_api,
|
|
||||||
is_vector_env=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def call(self, name, *args, **kwargs) -> tuple:
|
def call(self, name, *args, **kwargs) -> tuple:
|
||||||
|
@@ -4,7 +4,6 @@ from typing import Any, List, Optional, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym.logger import deprecation
|
|
||||||
from gym.vector.utils.spaces import batch_space
|
from gym.vector.utils.spaces import batch_space
|
||||||
|
|
||||||
__all__ = ["VectorEnv"]
|
__all__ = ["VectorEnv"]
|
||||||
@@ -28,7 +27,6 @@ class VectorEnv(gym.Env):
|
|||||||
num_envs: int,
|
num_envs: int,
|
||||||
observation_space: gym.Space,
|
observation_space: gym.Space,
|
||||||
action_space: gym.Space,
|
action_space: gym.Space,
|
||||||
new_step_api: bool = False,
|
|
||||||
):
|
):
|
||||||
"""Base class for vectorized environments.
|
"""Base class for vectorized environments.
|
||||||
|
|
||||||
@@ -36,7 +34,6 @@ class VectorEnv(gym.Env):
|
|||||||
num_envs: Number of environments in the vectorized environment.
|
num_envs: Number of environments in the vectorized environment.
|
||||||
observation_space: Observation space of a single environment.
|
observation_space: Observation space of a single environment.
|
||||||
action_space: Action space of a single environment.
|
action_space: Action space of a single environment.
|
||||||
new_step_api (bool): Whether the vector environment's step method outputs two boolean arrays (new API) or one boolean array (old API)
|
|
||||||
"""
|
"""
|
||||||
self.num_envs = num_envs
|
self.num_envs = num_envs
|
||||||
self.is_vector_env = True
|
self.is_vector_env = True
|
||||||
@@ -51,12 +48,6 @@ class VectorEnv(gym.Env):
|
|||||||
self.single_observation_space = observation_space
|
self.single_observation_space = observation_space
|
||||||
self.single_action_space = action_space
|
self.single_action_space = action_space
|
||||||
|
|
||||||
self.new_step_api = new_step_api
|
|
||||||
if not self.new_step_api:
|
|
||||||
deprecation(
|
|
||||||
"Initializing vector env in old step API which returns one bool array instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset_async(
|
def reset_async(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
|
@@ -3,7 +3,6 @@ import numpy as np
|
|||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym.spaces import Box
|
from gym.spaces import Box
|
||||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
@@ -38,7 +37,6 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
grayscale_obs: bool = True,
|
grayscale_obs: bool = True,
|
||||||
grayscale_newaxis: bool = False,
|
grayscale_newaxis: bool = False,
|
||||||
scale_obs: bool = False,
|
scale_obs: bool = False,
|
||||||
new_step_api: bool = False,
|
|
||||||
):
|
):
|
||||||
"""Wrapper for Atari 2600 preprocessing.
|
"""Wrapper for Atari 2600 preprocessing.
|
||||||
|
|
||||||
@@ -60,7 +58,7 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
DependencyNotInstalled: opencv-python package not installed
|
DependencyNotInstalled: opencv-python package not installed
|
||||||
ValueError: Disable frame-skipping in the original env
|
ValueError: Disable frame-skipping in the original env
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api)
|
super().__init__(env)
|
||||||
if cv2 is None:
|
if cv2 is None:
|
||||||
raise gym.error.DependencyNotInstalled(
|
raise gym.error.DependencyNotInstalled(
|
||||||
"opencv-python package not installed, run `pip install gym[other]` to get dependencies for atari"
|
"opencv-python package not installed, run `pip install gym[other]` to get dependencies for atari"
|
||||||
@@ -119,9 +117,7 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
total_reward, terminated, truncated, info = 0.0, False, False, {}
|
total_reward, terminated, truncated, info = 0.0, False, False, {}
|
||||||
|
|
||||||
for t in range(self.frame_skip):
|
for t in range(self.frame_skip):
|
||||||
_, reward, terminated, truncated, info = step_api_compatibility(
|
_, reward, terminated, truncated, info = self.env.step(action)
|
||||||
self.env.step(action), True
|
|
||||||
)
|
|
||||||
total_reward += reward
|
total_reward += reward
|
||||||
self.game_over = terminated
|
self.game_over = terminated
|
||||||
|
|
||||||
@@ -143,10 +139,7 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
self.ale.getScreenGrayscale(self.obs_buffer[0])
|
self.ale.getScreenGrayscale(self.obs_buffer[0])
|
||||||
else:
|
else:
|
||||||
self.ale.getScreenRGB(self.obs_buffer[0])
|
self.ale.getScreenRGB(self.obs_buffer[0])
|
||||||
return step_api_compatibility(
|
return self._get_obs(), total_reward, terminated, truncated, info
|
||||||
(self._get_obs(), total_reward, terminated, truncated, info),
|
|
||||||
self.new_step_api,
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
"""Resets the environment using preprocessing."""
|
"""Resets the environment using preprocessing."""
|
||||||
@@ -159,9 +152,7 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
else 0
|
else 0
|
||||||
)
|
)
|
||||||
for _ in range(noops):
|
for _ in range(noops):
|
||||||
_, _, terminated, truncated, step_info = step_api_compatibility(
|
_, _, terminated, truncated, step_info = self.env.step(0)
|
||||||
self.env.step(0), True
|
|
||||||
)
|
|
||||||
reset_info.update(step_info)
|
reset_info.update(step_info)
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
_, reset_info = self.env.reset(**kwargs)
|
_, reset_info = self.env.reset(**kwargs)
|
||||||
|
@@ -1,6 +1,5 @@
|
|||||||
"""Wrapper that autoreset environments when `terminated=True` or `truncated=True`."""
|
"""Wrapper that autoreset environments when `terminated=True` or `truncated=True`."""
|
||||||
import gym
|
import gym
|
||||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
|
||||||
|
|
||||||
|
|
||||||
class AutoResetWrapper(gym.Wrapper):
|
class AutoResetWrapper(gym.Wrapper):
|
||||||
@@ -11,27 +10,27 @@ class AutoResetWrapper(gym.Wrapper):
|
|||||||
with new step API and ``(new_obs, final_reward, final_done, info)`` with the old step API.
|
with new step API and ``(new_obs, final_reward, final_done, info)`` with the old step API.
|
||||||
- ``new_obs`` is the first observation after calling :meth:`self.env.reset`
|
- ``new_obs`` is the first observation after calling :meth:`self.env.reset`
|
||||||
- ``final_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`.
|
- ``final_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`.
|
||||||
- ``final_done`` is always True. In the new API, either ``final_terminated`` or ``final_truncated`` is True
|
- ``final_terminated`` is the terminated value before calling :meth:`self.env.reset`.
|
||||||
|
- ``final_truncated`` is the truncated value before calling :meth:`self.env.reset`. ``final_terminated`` and ``final_truncated`` cannot both be False.
|
||||||
- ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`,
|
- ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`,
|
||||||
with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step`
|
with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step`
|
||||||
and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`.
|
and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`.
|
||||||
|
|
||||||
Warning: When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns `terminated` or `truncated`, a
|
Warning: When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns `terminated` or `truncated`, a
|
||||||
new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the
|
new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the
|
||||||
final reward and done state from the previous episode.
|
final reward, terminated and truncated state from the previous episode.
|
||||||
If you need the final state from the previous episode, you need to retrieve it via the
|
If you need the final state from the previous episode, you need to retrieve it via the
|
||||||
"final_observation" key in the info dict.
|
"final_observation" key in the info dict.
|
||||||
Make sure you know what you're doing if you use this wrapper!
|
Make sure you know what you're doing if you use this wrapper!
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, new_step_api: bool = False):
|
def __init__(self, env: gym.Env):
|
||||||
"""A class for providing an automatic reset functionality for gym environments when calling :meth:`self.step`.
|
"""A class for providing an automatic reset functionality for gym environments when calling :meth:`self.step`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env (gym.Env): The environment to apply the wrapper
|
env (gym.Env): The environment to apply the wrapper
|
||||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api)
|
super().__init__(env)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered.
|
"""Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered.
|
||||||
@@ -42,10 +41,7 @@ class AutoResetWrapper(gym.Wrapper):
|
|||||||
Returns:
|
Returns:
|
||||||
The autoreset environment :meth:`step`
|
The autoreset environment :meth:`step`
|
||||||
"""
|
"""
|
||||||
obs, reward, terminated, truncated, info = step_api_compatibility(
|
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||||
self.env.step(action), True
|
|
||||||
)
|
|
||||||
|
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
|
|
||||||
new_obs, new_info = self.env.reset()
|
new_obs, new_info = self.env.reset()
|
||||||
@@ -62,6 +58,4 @@ class AutoResetWrapper(gym.Wrapper):
|
|||||||
obs = new_obs
|
obs = new_obs
|
||||||
info = new_info
|
info = new_info
|
||||||
|
|
||||||
return step_api_compatibility(
|
return obs, reward, terminated, truncated, info
|
||||||
(obs, reward, terminated, truncated, info), self.new_step_api
|
|
||||||
)
|
|
||||||
|
@@ -26,7 +26,7 @@ class ClipAction(ActionWrapper):
|
|||||||
env: The environment to apply the wrapper
|
env: The environment to apply the wrapper
|
||||||
"""
|
"""
|
||||||
assert isinstance(env.action_space, Box)
|
assert isinstance(env.action_space, Box)
|
||||||
super().__init__(env, new_step_api=True)
|
super().__init__(env)
|
||||||
|
|
||||||
def action(self, action):
|
def action(self, action):
|
||||||
"""Clips the action within the valid bounds.
|
"""Clips the action within the valid bounds.
|
||||||
|
@@ -15,7 +15,7 @@ class PassiveEnvChecker(gym.Wrapper):
|
|||||||
|
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
"""Initialises the wrapper with the environments, run the observation and action space tests."""
|
"""Initialises the wrapper with the environments, run the observation and action space tests."""
|
||||||
super().__init__(env, new_step_api=True)
|
super().__init__(env)
|
||||||
|
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
env, "action_space"
|
env, "action_space"
|
||||||
|
@@ -35,7 +35,7 @@ class FilterObservation(gym.ObservationWrapper):
|
|||||||
ValueError: If the environment's observation space is not :class:`spaces.Dict`
|
ValueError: If the environment's observation space is not :class:`spaces.Dict`
|
||||||
ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space
|
ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api=True)
|
super().__init__(env)
|
||||||
|
|
||||||
wrapped_observation_space = env.observation_space
|
wrapped_observation_space = env.observation_space
|
||||||
if not isinstance(wrapped_observation_space, spaces.Dict):
|
if not isinstance(wrapped_observation_space, spaces.Dict):
|
||||||
|
@@ -25,7 +25,7 @@ class FlattenObservation(gym.ObservationWrapper):
|
|||||||
Args:
|
Args:
|
||||||
env: The environment to apply the wrapper
|
env: The environment to apply the wrapper
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api=True)
|
super().__init__(env)
|
||||||
self.observation_space = spaces.flatten_space(env.observation_space)
|
self.observation_space = spaces.flatten_space(env.observation_space)
|
||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
|
@@ -7,7 +7,6 @@ import numpy as np
|
|||||||
import gym
|
import gym
|
||||||
from gym.error import DependencyNotInstalled
|
from gym.error import DependencyNotInstalled
|
||||||
from gym.spaces import Box
|
from gym.spaces import Box
|
||||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
|
||||||
|
|
||||||
|
|
||||||
class LazyFrames:
|
class LazyFrames:
|
||||||
@@ -128,7 +127,6 @@ class FrameStack(gym.ObservationWrapper):
|
|||||||
env: gym.Env,
|
env: gym.Env,
|
||||||
num_stack: int,
|
num_stack: int,
|
||||||
lz4_compress: bool = False,
|
lz4_compress: bool = False,
|
||||||
new_step_api: bool = False,
|
|
||||||
):
|
):
|
||||||
"""Observation wrapper that stacks the observations in a rolling manner.
|
"""Observation wrapper that stacks the observations in a rolling manner.
|
||||||
|
|
||||||
@@ -136,9 +134,8 @@ class FrameStack(gym.ObservationWrapper):
|
|||||||
env (Env): The environment to apply the wrapper
|
env (Env): The environment to apply the wrapper
|
||||||
num_stack (int): The number of frames to stack
|
num_stack (int): The number of frames to stack
|
||||||
lz4_compress (bool): Use lz4 to compress the frames internally
|
lz4_compress (bool): Use lz4 to compress the frames internally
|
||||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api)
|
super().__init__(env)
|
||||||
self.num_stack = num_stack
|
self.num_stack = num_stack
|
||||||
self.lz4_compress = lz4_compress
|
self.lz4_compress = lz4_compress
|
||||||
|
|
||||||
@@ -173,14 +170,9 @@ class FrameStack(gym.ObservationWrapper):
|
|||||||
Returns:
|
Returns:
|
||||||
Stacked observations, reward, terminated, truncated, and information from the environment
|
Stacked observations, reward, terminated, truncated, and information from the environment
|
||||||
"""
|
"""
|
||||||
observation, reward, terminated, truncated, info = step_api_compatibility(
|
observation, reward, terminated, truncated, info = self.env.step(action)
|
||||||
self.env.step(action), True
|
|
||||||
)
|
|
||||||
self.frames.append(observation)
|
self.frames.append(observation)
|
||||||
return step_api_compatibility(
|
return self.observation(None), reward, terminated, truncated, info
|
||||||
(self.observation(None), reward, terminated, truncated, info),
|
|
||||||
self.new_step_api,
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
"""Reset the environment with kwargs.
|
"""Reset the environment with kwargs.
|
||||||
|
@@ -28,7 +28,7 @@ class GrayScaleObservation(gym.ObservationWrapper):
|
|||||||
keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1.
|
keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1.
|
||||||
Otherwise, they are of shape AxB.
|
Otherwise, they are of shape AxB.
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api=True)
|
super().__init__(env)
|
||||||
self.keep_dim = keep_dim
|
self.keep_dim = keep_dim
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
|
@@ -45,7 +45,7 @@ class HumanRendering(gym.Wrapper):
|
|||||||
Args:
|
Args:
|
||||||
env: The environment that is being wrapped
|
env: The environment that is being wrapped
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api=True)
|
super().__init__(env)
|
||||||
assert env.render_mode in [
|
assert env.render_mode in [
|
||||||
"single_rgb_array",
|
"single_rgb_array",
|
||||||
"rgb_array",
|
"rgb_array",
|
||||||
|
@@ -2,7 +2,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
|
||||||
|
|
||||||
|
|
||||||
# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
|
# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
|
||||||
@@ -55,15 +54,14 @@ class NormalizeObservation(gym.core.Wrapper):
|
|||||||
newly instantiated or the policy was changed recently.
|
newly instantiated or the policy was changed recently.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, epsilon: float = 1e-8, new_step_api: bool = False):
|
def __init__(self, env: gym.Env, epsilon: float = 1e-8):
|
||||||
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
"""This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env (Env): The environment to apply the wrapper
|
env (Env): The environment to apply the wrapper
|
||||||
epsilon: A stability parameter that is used when scaling the observations.
|
epsilon: A stability parameter that is used when scaling the observations.
|
||||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api)
|
super().__init__(env)
|
||||||
self.num_envs = getattr(env, "num_envs", 1)
|
self.num_envs = getattr(env, "num_envs", 1)
|
||||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||||
if self.is_vector_env:
|
if self.is_vector_env:
|
||||||
@@ -74,18 +72,12 @@ class NormalizeObservation(gym.core.Wrapper):
|
|||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""Steps through the environment and normalizes the observation."""
|
"""Steps through the environment and normalizes the observation."""
|
||||||
obs, rews, terminateds, truncateds, infos = step_api_compatibility(
|
obs, rews, terminateds, truncateds, infos = self.env.step(action)
|
||||||
self.env.step(action), True, self.is_vector_env
|
|
||||||
)
|
|
||||||
if self.is_vector_env:
|
if self.is_vector_env:
|
||||||
obs = self.normalize(obs)
|
obs = self.normalize(obs)
|
||||||
else:
|
else:
|
||||||
obs = self.normalize(np.array([obs]))[0]
|
obs = self.normalize(np.array([obs]))[0]
|
||||||
return step_api_compatibility(
|
return obs, rews, terminateds, truncateds, infos
|
||||||
(obs, rews, terminateds, truncateds, infos),
|
|
||||||
self.new_step_api,
|
|
||||||
self.is_vector_env,
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
"""Resets the environment and normalizes the observation."""
|
"""Resets the environment and normalizes the observation."""
|
||||||
@@ -117,7 +109,6 @@ class NormalizeReward(gym.core.Wrapper):
|
|||||||
env: gym.Env,
|
env: gym.Env,
|
||||||
gamma: float = 0.99,
|
gamma: float = 0.99,
|
||||||
epsilon: float = 1e-8,
|
epsilon: float = 1e-8,
|
||||||
new_step_api: bool = False,
|
|
||||||
):
|
):
|
||||||
"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
|
||||||
|
|
||||||
@@ -125,9 +116,8 @@ class NormalizeReward(gym.core.Wrapper):
|
|||||||
env (env): The environment to apply the wrapper
|
env (env): The environment to apply the wrapper
|
||||||
epsilon (float): A stability parameter
|
epsilon (float): A stability parameter
|
||||||
gamma (float): The discount factor that is used in the exponential moving average.
|
gamma (float): The discount factor that is used in the exponential moving average.
|
||||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api)
|
super().__init__(env)
|
||||||
self.num_envs = getattr(env, "num_envs", 1)
|
self.num_envs = getattr(env, "num_envs", 1)
|
||||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||||
self.return_rms = RunningMeanStd(shape=())
|
self.return_rms = RunningMeanStd(shape=())
|
||||||
@@ -137,25 +127,16 @@ class NormalizeReward(gym.core.Wrapper):
|
|||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""Steps through the environment, normalizing the rewards returned."""
|
"""Steps through the environment, normalizing the rewards returned."""
|
||||||
obs, rews, terminateds, truncateds, infos = step_api_compatibility(
|
obs, rews, terminateds, truncateds, infos = self.env.step(action)
|
||||||
self.env.step(action), True, self.is_vector_env
|
|
||||||
)
|
|
||||||
if not self.is_vector_env:
|
if not self.is_vector_env:
|
||||||
rews = np.array([rews])
|
rews = np.array([rews])
|
||||||
self.returns = self.returns * self.gamma + rews
|
self.returns = self.returns * self.gamma + rews
|
||||||
rews = self.normalize(rews)
|
rews = self.normalize(rews)
|
||||||
if not self.is_vector_env:
|
dones = np.logical_or(terminateds, truncateds)
|
||||||
dones = terminateds or truncateds
|
|
||||||
else:
|
|
||||||
dones = np.bitwise_or(terminateds, truncateds)
|
|
||||||
self.returns[dones] = 0.0
|
self.returns[dones] = 0.0
|
||||||
if not self.is_vector_env:
|
if not self.is_vector_env:
|
||||||
rews = rews[0]
|
rews = rews[0]
|
||||||
return step_api_compatibility(
|
return obs, rews, terminateds, truncateds, infos
|
||||||
(obs, rews, terminateds, truncateds, infos),
|
|
||||||
self.new_step_api,
|
|
||||||
self.is_vector_env,
|
|
||||||
)
|
|
||||||
|
|
||||||
def normalize(self, rews):
|
def normalize(self, rews):
|
||||||
"""Normalizes the rewards with the running mean rewards and their variance."""
|
"""Normalizes the rewards with the running mean rewards and their variance."""
|
||||||
|
@@ -26,7 +26,7 @@ class OrderEnforcing(gym.Wrapper):
|
|||||||
env: The environment to wrap
|
env: The environment to wrap
|
||||||
disable_render_order_enforcing: If to disable render order enforcing
|
disable_render_order_enforcing: If to disable render order enforcing
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api=True)
|
super().__init__(env)
|
||||||
self._has_reset: bool = False
|
self._has_reset: bool = False
|
||||||
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
|
self._disable_render_order_enforcing: bool = disable_render_order_enforcing
|
||||||
|
|
||||||
|
@@ -77,7 +77,7 @@ class PixelObservationWrapper(gym.ObservationWrapper):
|
|||||||
specified ``pixel_keys``.
|
specified ``pixel_keys``.
|
||||||
TypeError: When an unexpected pixel type is used
|
TypeError: When an unexpected pixel type is used
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api=True)
|
super().__init__(env)
|
||||||
|
|
||||||
# Avoid side-effects that occur when render_kwargs is manipulated
|
# Avoid side-effects that occur when render_kwargs is manipulated
|
||||||
render_kwargs = copy.deepcopy(render_kwargs)
|
render_kwargs = copy.deepcopy(render_kwargs)
|
||||||
|
@@ -6,7 +6,6 @@ from typing import Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
|
||||||
|
|
||||||
|
|
||||||
def add_vector_episode_statistics(
|
def add_vector_episode_statistics(
|
||||||
@@ -77,15 +76,14 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
|||||||
length_queue: The lengths of the last ``deque_size``-many episodes
|
length_queue: The lengths of the last ``deque_size``-many episodes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, deque_size: int = 100, new_step_api: bool = False):
|
def __init__(self, env: gym.Env, deque_size: int = 100):
|
||||||
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env (Env): The environment to apply the wrapper
|
env (Env): The environment to apply the wrapper
|
||||||
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
||||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api)
|
super().__init__(env)
|
||||||
self.num_envs = getattr(env, "num_envs", 1)
|
self.num_envs = getattr(env, "num_envs", 1)
|
||||||
self.t0 = time.perf_counter()
|
self.t0 = time.perf_counter()
|
||||||
self.episode_count = 0
|
self.episode_count = 0
|
||||||
@@ -110,7 +108,7 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
|||||||
terminateds,
|
terminateds,
|
||||||
truncateds,
|
truncateds,
|
||||||
infos,
|
infos,
|
||||||
) = step_api_compatibility(self.env.step(action), True, self.is_vector_env)
|
) = self.env.step(action)
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
infos, dict
|
infos, dict
|
||||||
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
|
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
|
||||||
@@ -144,14 +142,10 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
|||||||
self.episode_count += 1
|
self.episode_count += 1
|
||||||
self.episode_returns[i] = 0
|
self.episode_returns[i] = 0
|
||||||
self.episode_lengths[i] = 0
|
self.episode_lengths[i] = 0
|
||||||
return step_api_compatibility(
|
return (
|
||||||
(
|
|
||||||
observations,
|
observations,
|
||||||
rewards,
|
rewards,
|
||||||
terminateds if self.is_vector_env else terminateds[0],
|
terminateds if self.is_vector_env else terminateds[0],
|
||||||
truncateds if self.is_vector_env else truncateds[0],
|
truncateds if self.is_vector_env else truncateds[0],
|
||||||
infos,
|
infos,
|
||||||
),
|
|
||||||
self.new_step_api,
|
|
||||||
self.is_vector_env,
|
|
||||||
)
|
)
|
||||||
|
@@ -4,7 +4,6 @@ from typing import Callable, Optional
|
|||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym import logger
|
from gym import logger
|
||||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
|
||||||
from gym.wrappers.monitoring import video_recorder
|
from gym.wrappers.monitoring import video_recorder
|
||||||
|
|
||||||
|
|
||||||
@@ -46,7 +45,6 @@ class RecordVideo(gym.Wrapper):
|
|||||||
step_trigger: Callable[[int], bool] = None,
|
step_trigger: Callable[[int], bool] = None,
|
||||||
video_length: int = 0,
|
video_length: int = 0,
|
||||||
name_prefix: str = "rl-video",
|
name_prefix: str = "rl-video",
|
||||||
new_step_api: bool = False,
|
|
||||||
):
|
):
|
||||||
"""Wrapper records videos of rollouts.
|
"""Wrapper records videos of rollouts.
|
||||||
|
|
||||||
@@ -58,9 +56,8 @@ class RecordVideo(gym.Wrapper):
|
|||||||
video_length (int): The length of recorded episodes. If 0, entire episodes are recorded.
|
video_length (int): The length of recorded episodes. If 0, entire episodes are recorded.
|
||||||
Otherwise, snippets of the specified length are captured
|
Otherwise, snippets of the specified length are captured
|
||||||
name_prefix (str): Will be prepended to the filename of the recordings
|
name_prefix (str): Will be prepended to the filename of the recordings
|
||||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api)
|
super().__init__(env)
|
||||||
|
|
||||||
if episode_trigger is None and step_trigger is None:
|
if episode_trigger is None and step_trigger is None:
|
||||||
episode_trigger = capped_cubic_video_schedule
|
episode_trigger = capped_cubic_video_schedule
|
||||||
@@ -142,7 +139,7 @@ class RecordVideo(gym.Wrapper):
|
|||||||
terminateds,
|
terminateds,
|
||||||
truncateds,
|
truncateds,
|
||||||
infos,
|
infos,
|
||||||
) = step_api_compatibility(self.env.step(action), True, self.is_vector_env)
|
) = self.env.step(action)
|
||||||
|
|
||||||
if not (self.terminated or self.truncated):
|
if not (self.terminated or self.truncated):
|
||||||
# increment steps and episodes
|
# increment steps and episodes
|
||||||
@@ -174,11 +171,7 @@ class RecordVideo(gym.Wrapper):
|
|||||||
elif self._video_enabled():
|
elif self._video_enabled():
|
||||||
self.start_video_recorder()
|
self.start_video_recorder()
|
||||||
|
|
||||||
return step_api_compatibility(
|
return observations, rewards, terminateds, truncateds, infos
|
||||||
(observations, rewards, terminateds, truncateds, infos),
|
|
||||||
self.new_step_api,
|
|
||||||
self.is_vector_env,
|
|
||||||
)
|
|
||||||
|
|
||||||
def close_video_recorder(self):
|
def close_video_recorder(self):
|
||||||
"""Closes the video recorder if currently recording."""
|
"""Closes the video recorder if currently recording."""
|
||||||
|
@@ -45,7 +45,7 @@ class RescaleAction(gym.ActionWrapper):
|
|||||||
), f"expected Box action space, got {type(env.action_space)}"
|
), f"expected Box action space, got {type(env.action_space)}"
|
||||||
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
|
assert np.less_equal(min_action, max_action).all(), (min_action, max_action)
|
||||||
|
|
||||||
super().__init__(env, new_step_api=True)
|
super().__init__(env)
|
||||||
self.min_action = (
|
self.min_action = (
|
||||||
np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
|
np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
|
||||||
)
|
)
|
||||||
|
@@ -32,7 +32,7 @@ class ResizeObservation(gym.ObservationWrapper):
|
|||||||
env: The environment to apply the wrapper
|
env: The environment to apply the wrapper
|
||||||
shape: The shape of the resized observations
|
shape: The shape of the resized observations
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api=True)
|
super().__init__(env)
|
||||||
if isinstance(shape, int):
|
if isinstance(shape, int):
|
||||||
shape = (shape, shape)
|
shape = (shape, shape)
|
||||||
assert all(x > 0 for x in shape), shape
|
assert all(x > 0 for x in shape), shape
|
||||||
|
@@ -1,7 +1,10 @@
|
|||||||
"""Implementation of StepAPICompatibility wrapper class for transforming envs between new and old step API."""
|
"""Implementation of StepAPICompatibility wrapper class for transforming envs between new and old step API."""
|
||||||
import gym
|
import gym
|
||||||
from gym.logger import deprecation
|
from gym.logger import deprecation
|
||||||
from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api
|
from gym.utils.step_api_compatibility import (
|
||||||
|
convert_to_done_step_api,
|
||||||
|
convert_to_terminated_truncated_step_api,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StepAPICompatibility(gym.Wrapper):
|
class StepAPICompatibility(gym.Wrapper):
|
||||||
@@ -11,37 +14,36 @@ class StepAPICompatibility(gym.Wrapper):
|
|||||||
New step API refers to step() method returning (observation, reward, terminated, truncated, info)
|
New step API refers to step() method returning (observation, reward, terminated, truncated, info)
|
||||||
(Refer to docs for details on the API change)
|
(Refer to docs for details on the API change)
|
||||||
|
|
||||||
This wrapper is to be used to ease transition to new API and for backward compatibility.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env (gym.Env): the env to wrap. Can be in old or new API
|
env (gym.Env): the env to wrap. Can be in old or new API
|
||||||
new_step_api (bool): True to use env with new step API, False to use env with old step API. (False by default)
|
output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API). (True by default)
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> env = gym.make("CartPole-v1")
|
>>> env = gym.make("CartPole-v1")
|
||||||
>>> env # wrapper applied by default, set to old API
|
>>> env # wrapper not applied by default, set to new API
|
||||||
<TimeLimit<OrderEnforcing<StepAPICompatibility<CartPoleEnv<CartPole-v1>>>>>
|
<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>
|
||||||
>>> env = gym.make("CartPole-v1", new_step_api=True) # set to new API
|
>>> env = gym.make("CartPole-v1", apply_step_compatibility=True) # set to old API
|
||||||
>>> env = StepAPICompatibility(CustomEnv(), new_step_api=True) # manually using wrapper on unregistered envs
|
<StepAPICompatibility<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>>
|
||||||
|
>>> env = StepAPICompatibility(CustomEnv(), output_truncation_bool=False) # manually using wrapper on unregistered envs
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, new_step_api=False):
|
def __init__(self, env: gym.Env, output_truncation_bool: bool = True):
|
||||||
"""A wrapper which can transform an environment from new step API to old and vice-versa.
|
"""A wrapper which can transform an environment from new step API to old and vice-versa.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env (gym.Env): the env to wrap. Can be in old or new API
|
env (gym.Env): the env to wrap. Can be in old or new API
|
||||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api)
|
super().__init__(env)
|
||||||
self.new_step_api = new_step_api
|
self.output_truncation_bool = output_truncation_bool
|
||||||
if not self.new_step_api:
|
if not self.output_truncation_bool:
|
||||||
deprecation(
|
deprecation(
|
||||||
"Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
|
"Initializing environment in old step API which returns one bool instead of two."
|
||||||
)
|
)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""Steps through the environment, returning 5 or 4 items depending on `new_step_api`.
|
"""Steps through the environment, returning 5 or 4 items depending on `output_truncation_bool`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
action: action to step through the environment with
|
action: action to step through the environment with
|
||||||
@@ -50,7 +52,7 @@ class StepAPICompatibility(gym.Wrapper):
|
|||||||
(observation, reward, terminated, truncated, info) or (observation, reward, done, info)
|
(observation, reward, terminated, truncated, info) or (observation, reward, done, info)
|
||||||
"""
|
"""
|
||||||
step_returns = self.env.step(action)
|
step_returns = self.env.step(action)
|
||||||
if self.new_step_api:
|
if self.output_truncation_bool:
|
||||||
return step_to_new_api(step_returns)
|
return convert_to_terminated_truncated_step_api(step_returns)
|
||||||
else:
|
else:
|
||||||
return step_to_old_api(step_returns)
|
return convert_to_done_step_api(step_returns)
|
||||||
|
@@ -3,7 +3,6 @@ import numpy as np
|
|||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym.spaces import Box
|
from gym.spaces import Box
|
||||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
|
||||||
|
|
||||||
|
|
||||||
class TimeAwareObservation(gym.ObservationWrapper):
|
class TimeAwareObservation(gym.ObservationWrapper):
|
||||||
@@ -22,14 +21,13 @@ class TimeAwareObservation(gym.ObservationWrapper):
|
|||||||
array([ 0.03881167, -0.16021058, 0.0220928 , 0.28875574, 1. ])
|
array([ 0.03881167, -0.16021058, 0.0220928 , 0.28875574, 1. ])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, new_step_api: bool = False):
|
def __init__(self, env: gym.Env):
|
||||||
"""Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` observation space.
|
"""Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` observation space.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env: The environment to apply the wrapper
|
env: The environment to apply the wrapper
|
||||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api)
|
super().__init__(env)
|
||||||
assert isinstance(env.observation_space, Box)
|
assert isinstance(env.observation_space, Box)
|
||||||
assert env.observation_space.dtype == np.float32
|
assert env.observation_space.dtype == np.float32
|
||||||
low = np.append(self.observation_space.low, 0.0)
|
low = np.append(self.observation_space.low, 0.0)
|
||||||
@@ -58,9 +56,7 @@ class TimeAwareObservation(gym.ObservationWrapper):
|
|||||||
The environment's step using the action.
|
The environment's step using the action.
|
||||||
"""
|
"""
|
||||||
self.t += 1
|
self.t += 1
|
||||||
return step_api_compatibility(
|
return super().step(action)
|
||||||
super().step(action), self.new_step_api, self.is_vector_env
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
"""Reset the environment setting the time to zero.
|
"""Reset the environment setting the time to zero.
|
||||||
|
@@ -2,7 +2,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
|
||||||
|
|
||||||
|
|
||||||
class TimeLimit(gym.Wrapper):
|
class TimeLimit(gym.Wrapper):
|
||||||
@@ -11,12 +10,6 @@ class TimeLimit(gym.Wrapper):
|
|||||||
If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
|
If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
|
||||||
Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP.
|
Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP.
|
||||||
|
|
||||||
(deprecated)
|
|
||||||
This information is passed through ``info`` that is returned when `done`-signal was issued.
|
|
||||||
The done-signal originates from the time limit (i.e. it signifies a *truncation*) if and only if
|
|
||||||
the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``. This will be removed in favour
|
|
||||||
of only issuing a `truncated` signal in future versions.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> from gym.envs.classic_control import CartPoleEnv
|
>>> from gym.envs.classic_control import CartPoleEnv
|
||||||
>>> from gym.wrappers import TimeLimit
|
>>> from gym.wrappers import TimeLimit
|
||||||
@@ -28,16 +21,14 @@ class TimeLimit(gym.Wrapper):
|
|||||||
self,
|
self,
|
||||||
env: gym.Env,
|
env: gym.Env,
|
||||||
max_episode_steps: Optional[int] = None,
|
max_episode_steps: Optional[int] = None,
|
||||||
new_step_api: bool = False,
|
|
||||||
):
|
):
|
||||||
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
|
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env: The environment to apply the wrapper
|
env: The environment to apply the wrapper
|
||||||
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
|
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
|
||||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api)
|
super().__init__(env)
|
||||||
if max_episode_steps is None and self.env.spec is not None:
|
if max_episode_steps is None and self.env.spec is not None:
|
||||||
max_episode_steps = env.spec.max_episode_steps
|
max_episode_steps = env.spec.max_episode_steps
|
||||||
if self.env.spec is not None:
|
if self.env.spec is not None:
|
||||||
@@ -52,26 +43,17 @@ class TimeLimit(gym.Wrapper):
|
|||||||
action: The environment step action
|
action: The environment step action
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The environment step ``(observation, reward, done, info)`` with "TimeLimit.truncated"=True
|
The environment step ``(observation, reward, terminated, truncated, info)`` with `truncated=True`
|
||||||
when truncated (the number of steps elapsed >= max episode steps) or
|
if the number of steps elapsed >= max episode steps
|
||||||
"TimeLimit.truncated"=False if the environment terminated
|
|
||||||
"""
|
"""
|
||||||
observation, reward, terminated, truncated, info = step_api_compatibility(
|
observation, reward, terminated, truncated, info = self.env.step(action)
|
||||||
self.env.step(action),
|
|
||||||
True,
|
|
||||||
)
|
|
||||||
self._elapsed_steps += 1
|
self._elapsed_steps += 1
|
||||||
|
|
||||||
if self._elapsed_steps >= self._max_episode_steps:
|
if self._elapsed_steps >= self._max_episode_steps:
|
||||||
if self.new_step_api is True or terminated is False:
|
|
||||||
# As the old step api cannot encode both terminated and truncated, we favor terminated in the case of both.
|
|
||||||
# Therefore, if new step api (i.e. not old step api) or when terminated is False to prevent the overriding
|
|
||||||
truncated = True
|
truncated = True
|
||||||
|
|
||||||
return step_api_compatibility(
|
return observation, reward, terminated, truncated, info
|
||||||
(observation, reward, terminated, truncated, info),
|
|
||||||
self.new_step_api,
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
"""Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.
|
"""Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.
|
||||||
|
@@ -27,7 +27,7 @@ class TransformObservation(gym.ObservationWrapper):
|
|||||||
env: The environment to apply the wrapper
|
env: The environment to apply the wrapper
|
||||||
f: A function that transforms the observation
|
f: A function that transforms the observation
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api=True)
|
super().__init__(env)
|
||||||
assert callable(f)
|
assert callable(f)
|
||||||
self.f = f
|
self.f = f
|
||||||
|
|
||||||
|
@@ -28,7 +28,7 @@ class TransformReward(RewardWrapper):
|
|||||||
env: The environment to apply the wrapper
|
env: The environment to apply the wrapper
|
||||||
f: A function that transforms the reward
|
f: A function that transforms the reward
|
||||||
"""
|
"""
|
||||||
super().__init__(env, new_step_api=True)
|
super().__init__(env)
|
||||||
assert callable(f)
|
assert callable(f)
|
||||||
self.f = f
|
self.f = f
|
||||||
|
|
||||||
|
@@ -3,7 +3,6 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
|
||||||
|
|
||||||
|
|
||||||
class VectorListInfo(gym.Wrapper):
|
class VectorListInfo(gym.Wrapper):
|
||||||
@@ -30,30 +29,23 @@ class VectorListInfo(gym.Wrapper):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env, new_step_api=False):
|
def __init__(self, env):
|
||||||
"""This wrapper will convert the info into the list format.
|
"""This wrapper will convert the info into the list format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env (Env): The environment to apply the wrapper
|
env (Env): The environment to apply the wrapper
|
||||||
new_step_api (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
|
|
||||||
"""
|
"""
|
||||||
assert getattr(
|
assert getattr(
|
||||||
env, "is_vector_env", False
|
env, "is_vector_env", False
|
||||||
), "This wrapper can only be used in vectorized environments."
|
), "This wrapper can only be used in vectorized environments."
|
||||||
super().__init__(env, new_step_api)
|
super().__init__(env)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""Steps through the environment, convert dict info to list."""
|
"""Steps through the environment, convert dict info to list."""
|
||||||
observation, reward, terminated, truncated, infos = step_api_compatibility(
|
observation, reward, terminated, truncated, infos = self.env.step(action)
|
||||||
self.env.step(action), True, True
|
|
||||||
)
|
|
||||||
list_info = self._convert_info_to_list(infos)
|
list_info = self._convert_info_to_list(infos)
|
||||||
|
|
||||||
return step_api_compatibility(
|
return observation, reward, terminated, truncated, list_info
|
||||||
(observation, reward, terminated, truncated, list_info),
|
|
||||||
self.new_step_api,
|
|
||||||
True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
"""Resets the environment using kwargs."""
|
"""Resets the environment using kwargs."""
|
||||||
|
@@ -112,14 +112,12 @@ def test_box_actions_out_of_bound(env: gym.Env):
|
|||||||
zip(env.action_space.bounded_above, env.action_space.bounded_below)
|
zip(env.action_space.bounded_above, env.action_space.bounded_below)
|
||||||
):
|
):
|
||||||
if is_upper_bound:
|
if is_upper_bound:
|
||||||
obs, _, _, _, _ = env.step(
|
obs, _, _, _, _ = env.step(upper_bounds)
|
||||||
upper_bounds
|
|
||||||
) # `env` is unwrapped, and in new step API
|
|
||||||
oob_action = upper_bounds.copy()
|
oob_action = upper_bounds.copy()
|
||||||
oob_action[i] += np.cast[dtype](OOB_VALUE)
|
oob_action[i] += np.cast[dtype](OOB_VALUE)
|
||||||
|
|
||||||
assert oob_action[i] > upper_bounds[i]
|
assert oob_action[i] > upper_bounds[i]
|
||||||
oob_obs, _, _, _ = oob_env.step(oob_action)
|
oob_obs, _, _, _, _ = oob_env.step(oob_action)
|
||||||
|
|
||||||
assert np.alltrue(obs == oob_obs)
|
assert np.alltrue(obs == oob_obs)
|
||||||
|
|
||||||
@@ -131,7 +129,7 @@ def test_box_actions_out_of_bound(env: gym.Env):
|
|||||||
oob_action[i] -= np.cast[dtype](OOB_VALUE)
|
oob_action[i] -= np.cast[dtype](OOB_VALUE)
|
||||||
|
|
||||||
assert oob_action[i] < lower_bounds[i]
|
assert oob_action[i] < lower_bounds[i]
|
||||||
oob_obs, _, _, _ = oob_env.step(oob_action)
|
oob_obs, _, _, _, _ = oob_env.step(oob_action)
|
||||||
|
|
||||||
assert np.alltrue(obs == oob_obs)
|
assert np.alltrue(obs == oob_obs)
|
||||||
|
|
||||||
|
@@ -18,8 +18,7 @@ PASSIVE_CHECK_IGNORE_WARNING = [
|
|||||||
f"\x1b[33mWARN: {message}\x1b[0m"
|
f"\x1b[33mWARN: {message}\x1b[0m"
|
||||||
for message in [
|
for message in [
|
||||||
"This version of the mujoco environments depends on the mujoco-py bindings, which are no longer maintained and may stop working. Please upgrade to the v4 versions of the environments (which depend on the mujoco python bindings instead), unless you are trying to precisely replicate previous works).",
|
"This version of the mujoco environments depends on the mujoco-py bindings, which are no longer maintained and may stop working. Please upgrade to the v4 versions of the environments (which depend on the mujoco python bindings instead), unless you are trying to precisely replicate previous works).",
|
||||||
"Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
|
"Initializing environment in done (old) step API which returns one bool instead of two.",
|
||||||
"Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
|
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -30,8 +29,7 @@ CHECK_ENV_IGNORE_WARNINGS = [
|
|||||||
"A Box observation space minimum value is -infinity. This is probably too low.",
|
"A Box observation space minimum value is -infinity. This is probably too low.",
|
||||||
"A Box observation space maximum value is -infinity. This is probably too high.",
|
"A Box observation space maximum value is -infinity. This is probably too high.",
|
||||||
"For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
|
"For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
|
||||||
"Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
|
"Initializing environment in done (old) step API which returns one bool instead of two.",
|
||||||
"Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.",
|
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -92,8 +90,8 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
|
|||||||
# We don't evaluate the determinism of actions
|
# We don't evaluate the determinism of actions
|
||||||
action = env_1.action_space.sample()
|
action = env_1.action_space.sample()
|
||||||
|
|
||||||
obs_1, rew_1, done_1, info_1 = env_1.step(action)
|
obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
|
||||||
obs_2, rew_2, done_2, info_2 = env_2.step(action)
|
obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)
|
||||||
|
|
||||||
assert_equals(obs_1, obs_2, f"[{time_step}] ")
|
assert_equals(obs_1, obs_2, f"[{time_step}] ")
|
||||||
assert env_1.observation_space.contains(
|
assert env_1.observation_space.contains(
|
||||||
@@ -101,10 +99,17 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
|
|||||||
) # obs_2 verified by previous assertion
|
) # obs_2 verified by previous assertion
|
||||||
|
|
||||||
assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
|
assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
|
||||||
assert done_1 == done_2, f"[{time_step}] done 1={done_1}, done 2={done_2}"
|
assert (
|
||||||
|
terminated_1 == terminated_2
|
||||||
|
), f"[{time_step}] done 1={terminated_1}, done 2={terminated_2}"
|
||||||
|
assert (
|
||||||
|
truncated_1 == truncated_2
|
||||||
|
), f"[{time_step}] done 1={truncated_1}, done 2={truncated_2}"
|
||||||
assert_equals(info_1, info_2, f"[{time_step}] ")
|
assert_equals(info_1, info_2, f"[{time_step}] ")
|
||||||
|
|
||||||
if done_1: # done_2 verified by previous assertion
|
if (
|
||||||
|
terminated_1 or truncated_1
|
||||||
|
): # terminated_2, truncated_2 verified by previous assertion
|
||||||
env_1.reset(seed=SEED)
|
env_1.reset(seed=SEED)
|
||||||
env_2.reset(seed=SEED)
|
env_2.reset(seed=SEED)
|
||||||
|
|
||||||
|
@@ -24,17 +24,22 @@ def verify_environments_match(
|
|||||||
|
|
||||||
for i in range(num_actions):
|
for i in range(num_actions):
|
||||||
action = old_env.action_space.sample()
|
action = old_env.action_space.sample()
|
||||||
old_obs, old_reward, old_done, old_info = old_env.step(action)
|
old_obs, old_reward, old_terminated, old_truncated, old_info = old_env.step(
|
||||||
new_obs, new_reward, new_done, new_info = new_env.step(action)
|
action
|
||||||
|
)
|
||||||
|
new_obs, new_reward, new_terminated, new_truncated, new_info = new_env.step(
|
||||||
|
action
|
||||||
|
)
|
||||||
|
|
||||||
np.testing.assert_allclose(old_obs, new_obs, atol=EPS)
|
np.testing.assert_allclose(old_obs, new_obs, atol=EPS)
|
||||||
np.testing.assert_allclose(old_reward, new_reward, atol=EPS)
|
np.testing.assert_allclose(old_reward, new_reward, atol=EPS)
|
||||||
np.testing.assert_equal(old_done, new_done)
|
np.testing.assert_equal(old_terminated, new_terminated)
|
||||||
|
np.testing.assert_equal(old_truncated, new_truncated)
|
||||||
|
|
||||||
for key in old_info:
|
for key in old_info:
|
||||||
np.testing.assert_allclose(old_info[key], new_info[key], atol=EPS)
|
np.testing.assert_allclose(old_info[key], new_info[key], atol=EPS)
|
||||||
|
|
||||||
if old_done:
|
if old_terminated or old_truncated:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
@@ -62,7 +67,7 @@ def test_obs_space_mujoco_environments(env_spec: EnvSpec):
|
|||||||
), f"Obseravtion returned by reset() of {env_spec.id} is not contained in the default observation space {env.observation_space}."
|
), f"Obseravtion returned by reset() of {env_spec.id} is not contained in the default observation space {env.observation_space}."
|
||||||
|
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
step_obs, _, _, _ = env.step(action)
|
step_obs, _, _, _, _ = env.step(action)
|
||||||
assert env.observation_space.contains(
|
assert env.observation_space.contains(
|
||||||
step_obs
|
step_obs
|
||||||
), f"Obseravtion returned by step(action) of {env_spec.id} is not contained in the default observation space {env.observation_space}."
|
), f"Obseravtion returned by step(action) of {env_spec.id} is not contained in the default observation space {env.observation_space}."
|
||||||
@@ -78,7 +83,7 @@ def test_obs_space_mujoco_environments(env_spec: EnvSpec):
|
|||||||
reset_obs
|
reset_obs
|
||||||
), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when excluding current position from observation."
|
), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when excluding current position from observation."
|
||||||
|
|
||||||
step_obs, _, _, _ = env.step(action)
|
step_obs, _, _, _, _ = env.step(action)
|
||||||
assert env.observation_space.contains(
|
assert env.observation_space.contains(
|
||||||
step_obs
|
step_obs
|
||||||
), f"Obseravtion returned by step(action) of {env_spec.id} is not contained in the default observation space {env.observation_space} when excluding current position from observation."
|
), f"Obseravtion returned by step(action) of {env_spec.id} is not contained in the default observation space {env.observation_space} when excluding current position from observation."
|
||||||
@@ -91,7 +96,7 @@ def test_obs_space_mujoco_environments(env_spec: EnvSpec):
|
|||||||
reset_obs
|
reset_obs
|
||||||
), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when using contact forces."
|
), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when using contact forces."
|
||||||
|
|
||||||
step_obs, _, _, _ = env.step(action)
|
step_obs, _, _, _, _ = env.step(action)
|
||||||
assert env.observation_space.contains(
|
assert env.observation_space.contains(
|
||||||
step_obs
|
step_obs
|
||||||
), f"Obseravtion returned by step(action) of {env_spec.id} is not contained in the default observation space {env.observation_space} when using contact forces."
|
), f"Obseravtion returned by step(action) of {env_spec.id} is not contained in the default observation space {env.observation_space} when using contact forces."
|
||||||
|
@@ -27,8 +27,8 @@ class DummyPlayEnv(gym.Env):
|
|||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
obs = np.zeros((1, 1))
|
obs = np.zeros((1, 1))
|
||||||
rew, done, info = 1, False, {}
|
rew, terminated, truncated, info = 1, False, False, {}
|
||||||
return obs, rew, done, info
|
return obs, rew, terminated, truncated, info
|
||||||
|
|
||||||
def reset(self, seed=None):
|
def reset(self, seed=None):
|
||||||
...
|
...
|
||||||
@@ -52,9 +52,9 @@ class PlayStatus:
|
|||||||
self.cumulative_reward = 0
|
self.cumulative_reward = 0
|
||||||
self.last_observation = None
|
self.last_observation = None
|
||||||
|
|
||||||
def callback(self, obs_t, obs_tp1, action, rew, done, info):
|
def callback(self, obs_t, obs_tp1, action, rew, terminated, truncated, info):
|
||||||
_, obs_tp1, _, rew, _, _ = self.data_callback(
|
_, obs_tp1, _, rew, _, _, _ = self.data_callback(
|
||||||
obs_t, obs_tp1, action, rew, done, info
|
obs_t, obs_tp1, action, rew, terminated, truncated, info
|
||||||
)
|
)
|
||||||
self.cumulative_reward += rew
|
self.cumulative_reward += rew
|
||||||
self.last_observation = obs_tp1
|
self.last_observation = obs_tp1
|
||||||
@@ -177,7 +177,7 @@ def test_play_loop_real_env():
|
|||||||
]
|
]
|
||||||
keydown_events = [k for k in callback_events if k.type == KEYDOWN]
|
keydown_events = [k for k in callback_events if k.type == KEYDOWN]
|
||||||
|
|
||||||
def callback(obs_t, obs_tp1, action, rew, done, info):
|
def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info):
|
||||||
pygame_event = callback_events.pop(0)
|
pygame_event = callback_events.pop(0)
|
||||||
event.post(pygame_event)
|
event.post(pygame_event)
|
||||||
|
|
||||||
@@ -187,7 +187,7 @@ def test_play_loop_real_env():
|
|||||||
pygame_event = callback_events.pop(0)
|
pygame_event = callback_events.pop(0)
|
||||||
event.post(pygame_event)
|
event.post(pygame_event)
|
||||||
|
|
||||||
return obs_t, obs_tp1, action, rew, done, info
|
return obs_t, obs_tp1, action, rew, terminated, truncated, info
|
||||||
|
|
||||||
env = gym.make(ENV, render_mode="single_rgb_array", disable_env_checker=True)
|
env = gym.make(ENV, render_mode="single_rgb_array", disable_env_checker=True)
|
||||||
env.reset(seed=SEED)
|
env.reset(seed=SEED)
|
||||||
@@ -197,10 +197,10 @@ def test_play_loop_real_env():
|
|||||||
|
|
||||||
# first action is 0 because at the first iteration
|
# first action is 0 because at the first iteration
|
||||||
# we can not inject a callback event into play()
|
# we can not inject a callback event into play()
|
||||||
obs, _, _, _ = env.step(0)
|
obs, _, _, _, _ = env.step(0)
|
||||||
for e in keydown_events:
|
for e in keydown_events:
|
||||||
action = keys_to_action[chr(e.key) if str_keys else (e.key,)]
|
action = keys_to_action[chr(e.key) if str_keys else (e.key,)]
|
||||||
obs, _, _, _ = env.step(action)
|
obs, _, _, _, _ = env.step(action)
|
||||||
|
|
||||||
env_play = gym.make(
|
env_play = gym.make(
|
||||||
ENV, render_mode="single_rgb_array", disable_env_checker=True
|
ENV, render_mode="single_rgb_array", disable_env_checker=True
|
||||||
|
@@ -2,7 +2,10 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gym.utils.env_checker import data_equivalence
|
from gym.utils.env_checker import data_equivalence
|
||||||
from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api
|
from gym.utils.step_api_compatibility import (
|
||||||
|
convert_to_done_step_api,
|
||||||
|
convert_to_terminated_truncated_step_api,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -54,7 +57,7 @@ from gym.utils.step_api_compatibility import step_to_new_api, step_to_old_api
|
|||||||
def test_to_done_step_api(
|
def test_to_done_step_api(
|
||||||
is_vector_env, done_returns, expected_terminated, expected_truncated
|
is_vector_env, done_returns, expected_terminated, expected_truncated
|
||||||
):
|
):
|
||||||
_, _, terminated, truncated, info = step_to_new_api(
|
_, _, terminated, truncated, info = convert_to_terminated_truncated_step_api(
|
||||||
done_returns, is_vector_env=is_vector_env
|
done_returns, is_vector_env=is_vector_env
|
||||||
)
|
)
|
||||||
assert np.all(terminated == expected_terminated)
|
assert np.all(terminated == expected_terminated)
|
||||||
@@ -67,7 +70,7 @@ def test_to_done_step_api(
|
|||||||
else: # isinstance(info, dict)
|
else: # isinstance(info, dict)
|
||||||
assert "TimeLimit.truncated" not in info
|
assert "TimeLimit.truncated" not in info
|
||||||
|
|
||||||
roundtripped_returns = step_to_old_api(
|
roundtripped_returns = convert_to_done_step_api(
|
||||||
(0, 0, terminated, truncated, info), is_vector_env=is_vector_env
|
(0, 0, terminated, truncated, info), is_vector_env=is_vector_env
|
||||||
)
|
)
|
||||||
assert data_equivalence(done_returns, roundtripped_returns)
|
assert data_equivalence(done_returns, roundtripped_returns)
|
||||||
@@ -112,7 +115,7 @@ def test_to_done_step_api(
|
|||||||
def test_to_terminated_truncated_step_api(
|
def test_to_terminated_truncated_step_api(
|
||||||
is_vector_env, terminated_truncated_returns, expected_done, expected_truncated
|
is_vector_env, terminated_truncated_returns, expected_done, expected_truncated
|
||||||
):
|
):
|
||||||
_, _, done, info = step_to_old_api(
|
_, _, done, info = convert_to_done_step_api(
|
||||||
terminated_truncated_returns, is_vector_env=is_vector_env
|
terminated_truncated_returns, is_vector_env=is_vector_env
|
||||||
)
|
)
|
||||||
assert np.all(done == expected_done)
|
assert np.all(done == expected_done)
|
||||||
@@ -136,7 +139,7 @@ def test_to_terminated_truncated_step_api(
|
|||||||
else:
|
else:
|
||||||
assert "TimeLimit.truncated" not in info
|
assert "TimeLimit.truncated" not in info
|
||||||
|
|
||||||
roundtripped_returns = step_to_new_api(
|
roundtripped_returns = convert_to_terminated_truncated_step_api(
|
||||||
(0, 0, done, info), is_vector_env=is_vector_env
|
(0, 0, done, info), is_vector_env=is_vector_env
|
||||||
)
|
)
|
||||||
assert data_equivalence(terminated_truncated_returns, roundtripped_returns)
|
assert data_equivalence(terminated_truncated_returns, roundtripped_returns)
|
||||||
@@ -146,19 +149,19 @@ def test_edge_case():
|
|||||||
# When converting between the two-step APIs this is not possible in a single case
|
# When converting between the two-step APIs this is not possible in a single case
|
||||||
# terminated=True and truncated=True -> done=True and info={}
|
# terminated=True and truncated=True -> done=True and info={}
|
||||||
# We cannot test this in test_to_terminated_truncated_step_api as the roundtripping test will fail
|
# We cannot test this in test_to_terminated_truncated_step_api as the roundtripping test will fail
|
||||||
_, _, done, info = step_to_old_api((0, 0, True, True, {}))
|
_, _, done, info = convert_to_done_step_api((0, 0, True, True, {}))
|
||||||
assert done is True
|
assert done is True
|
||||||
assert info == {"TimeLimit.truncated": False}
|
assert info == {"TimeLimit.truncated": False}
|
||||||
|
|
||||||
# Test with vector dict info
|
# Test with vector dict info
|
||||||
_, _, done, info = step_to_old_api(
|
_, _, done, info = convert_to_done_step_api(
|
||||||
(0, 0, np.array([True]), np.array([True]), {}), is_vector_env=True
|
(0, 0, np.array([True]), np.array([True]), {}), is_vector_env=True
|
||||||
)
|
)
|
||||||
assert np.all(done)
|
assert np.all(done)
|
||||||
assert info == {"TimeLimit.truncated": np.array([False])}
|
assert info == {"TimeLimit.truncated": np.array([False])}
|
||||||
|
|
||||||
# Test with vector list info
|
# Test with vector list info
|
||||||
_, _, done, info = step_to_old_api(
|
_, _, done, info = convert_to_done_step_api(
|
||||||
(0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}]),
|
(0, 0, np.array([True]), np.array([True]), [{"Test-Info": True}]),
|
||||||
is_vector_env=True,
|
is_vector_env=True,
|
||||||
)
|
)
|
||||||
|
@@ -68,7 +68,7 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
|
|||||||
actions = [env.single_action_space.sample() for _ in range(8)]
|
actions = [env.single_action_space.sample() for _ in range(8)]
|
||||||
else:
|
else:
|
||||||
actions = env.action_space.sample()
|
actions = env.action_space.sample()
|
||||||
observations, rewards, dones, _ = env.step(actions)
|
observations, rewards, terminateds, truncateds, _ = env.step(actions)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
@@ -83,10 +83,15 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
|
|||||||
assert rewards.ndim == 1
|
assert rewards.ndim == 1
|
||||||
assert rewards.size == 8
|
assert rewards.size == 8
|
||||||
|
|
||||||
assert isinstance(dones, np.ndarray)
|
assert isinstance(terminateds, np.ndarray)
|
||||||
assert dones.dtype == np.bool_
|
assert terminateds.dtype == np.bool_
|
||||||
assert dones.ndim == 1
|
assert terminateds.ndim == 1
|
||||||
assert dones.size == 8
|
assert terminateds.size == 8
|
||||||
|
|
||||||
|
assert isinstance(truncateds, np.ndarray)
|
||||||
|
assert truncateds.dtype == np.bool_
|
||||||
|
assert truncateds.ndim == 1
|
||||||
|
assert truncateds.size == 8
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||||
@@ -169,7 +174,7 @@ def test_step_timeout_async_vector_env(shared_memory):
|
|||||||
with pytest.raises(TimeoutError):
|
with pytest.raises(TimeoutError):
|
||||||
env.reset()
|
env.reset()
|
||||||
env.step_async(np.array([0.1, 0.1, 0.3, 0.1]))
|
env.step_async(np.array([0.1, 0.1, 0.3, 0.1]))
|
||||||
observations, rewards, dones, _ = env.step_wait(timeout=0.1)
|
observations, rewards, terminateds, truncateds, _ = env.step_wait(timeout=0.1)
|
||||||
env.close(terminate=True)
|
env.close(terminate=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -262,7 +267,7 @@ def test_custom_space_async_vector_env():
|
|||||||
assert isinstance(env.action_space, Tuple)
|
assert isinstance(env.action_space, Tuple)
|
||||||
|
|
||||||
actions = ("action-2", "action-3", "action-5", "action-7")
|
actions = ("action-2", "action-3", "action-5", "action-7")
|
||||||
step_observations, rewards, dones, _ = env.step(actions)
|
step_observations, rewards, terminateds, truncateds, _ = env.step(actions)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
@@ -1,88 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import gym
|
|
||||||
from gym.spaces import Discrete
|
|
||||||
from gym.vector import AsyncVectorEnv, SyncVectorEnv
|
|
||||||
|
|
||||||
|
|
||||||
class OldStepEnv(gym.Env):
|
|
||||||
def __init__(self):
|
|
||||||
self.action_space = Discrete(2)
|
|
||||||
self.observation_space = Discrete(2)
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
return 0, {}
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
obs = self.observation_space.sample()
|
|
||||||
rew = 0
|
|
||||||
done = False
|
|
||||||
info = {}
|
|
||||||
return obs, rew, done, info
|
|
||||||
|
|
||||||
|
|
||||||
class NewStepEnv(gym.Env):
|
|
||||||
def __init__(self):
|
|
||||||
self.action_space = Discrete(2)
|
|
||||||
self.observation_space = Discrete(2)
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
return 0, {}
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
obs = self.observation_space.sample()
|
|
||||||
rew = 0
|
|
||||||
terminated = False
|
|
||||||
truncated = False
|
|
||||||
info = {}
|
|
||||||
return obs, rew, terminated, truncated, info
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("VecEnv", [AsyncVectorEnv, SyncVectorEnv])
|
|
||||||
def test_vector_step_compatibility_new_env(VecEnv):
|
|
||||||
|
|
||||||
envs = [
|
|
||||||
OldStepEnv(),
|
|
||||||
NewStepEnv(),
|
|
||||||
]
|
|
||||||
|
|
||||||
vec_env = VecEnv([lambda: env for env in envs])
|
|
||||||
vec_env.reset()
|
|
||||||
step_returns = vec_env.step([0, 0])
|
|
||||||
assert len(step_returns) == 4
|
|
||||||
_, _, dones, _ = step_returns
|
|
||||||
assert dones.dtype == np.bool_
|
|
||||||
vec_env.close()
|
|
||||||
|
|
||||||
vec_env = VecEnv([lambda: env for env in envs], new_step_api=True)
|
|
||||||
vec_env.reset()
|
|
||||||
step_returns = vec_env.step([0, 0])
|
|
||||||
assert len(step_returns) == 5
|
|
||||||
_, _, terminateds, truncateds, _ = step_returns
|
|
||||||
assert terminateds.dtype == np.bool_
|
|
||||||
assert truncateds.dtype == np.bool_
|
|
||||||
vec_env.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("async_bool", [True, False])
|
|
||||||
def test_vector_step_compatibility_existing(async_bool):
|
|
||||||
|
|
||||||
env = gym.vector.make("CartPole-v1", num_envs=3, asynchronous=async_bool)
|
|
||||||
env.reset()
|
|
||||||
step_returns = env.step(env.action_space.sample())
|
|
||||||
assert len(step_returns) == 4
|
|
||||||
_, _, dones, _ = step_returns
|
|
||||||
assert dones.dtype == np.bool_
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
env = gym.vector.make(
|
|
||||||
"CartPole-v1", num_envs=3, asynchronous=async_bool, new_step_api=True
|
|
||||||
)
|
|
||||||
env.reset()
|
|
||||||
step_returns = env.step(env.action_space.sample())
|
|
||||||
assert len(step_returns) == 5
|
|
||||||
_, _, terminateds, truncateds, _ = step_returns
|
|
||||||
assert terminateds.dtype == np.bool_
|
|
||||||
assert truncateds.dtype == np.bool_
|
|
||||||
env.close()
|
|
@@ -50,7 +50,7 @@ def test_step_sync_vector_env(use_single_action_space):
|
|||||||
actions = [env.single_action_space.sample() for _ in range(8)]
|
actions = [env.single_action_space.sample() for _ in range(8)]
|
||||||
else:
|
else:
|
||||||
actions = env.action_space.sample()
|
actions = env.action_space.sample()
|
||||||
observations, rewards, dones, _ = env.step(actions)
|
observations, rewards, terminateds, truncateds, _ = env.step(actions)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
@@ -65,10 +65,15 @@ def test_step_sync_vector_env(use_single_action_space):
|
|||||||
assert rewards.ndim == 1
|
assert rewards.ndim == 1
|
||||||
assert rewards.size == 8
|
assert rewards.size == 8
|
||||||
|
|
||||||
assert isinstance(dones, np.ndarray)
|
assert isinstance(terminateds, np.ndarray)
|
||||||
assert dones.dtype == np.bool_
|
assert terminateds.dtype == np.bool_
|
||||||
assert dones.ndim == 1
|
assert terminateds.ndim == 1
|
||||||
assert dones.size == 8
|
assert terminateds.size == 8
|
||||||
|
|
||||||
|
assert isinstance(truncateds, np.ndarray)
|
||||||
|
assert truncateds.dtype == np.bool_
|
||||||
|
assert truncateds.ndim == 1
|
||||||
|
assert truncateds.size == 8
|
||||||
|
|
||||||
|
|
||||||
def test_call_sync_vector_env():
|
def test_call_sync_vector_env():
|
||||||
@@ -125,7 +130,7 @@ def test_custom_space_sync_vector_env():
|
|||||||
assert isinstance(env.action_space, Tuple)
|
assert isinstance(env.action_space, Tuple)
|
||||||
|
|
||||||
actions = ("action-2", "action-3", "action-5", "action-7")
|
actions = ("action-2", "action-3", "action-5", "action-7")
|
||||||
step_observations, rewards, dones, _ = env.step(actions)
|
step_observations, rewards, terminateds, truncateds, _ = env.step(actions)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
@@ -31,11 +31,11 @@ def test_vector_env_equal(shared_memory):
|
|||||||
assert actions in sync_env.action_space
|
assert actions in sync_env.action_space
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
async_observations, async_rewards, async_dones, async_infos = async_env.step(actions)
|
async_observations, async_rewards, async_terminateds, async_truncateds, async_infos = async_env.step(actions)
|
||||||
sync_observations, sync_rewards, sync_dones, sync_infos = sync_env.step(actions)
|
sync_observations, sync_rewards, sync_terminateds, sync_truncateds, sync_infos = sync_env.step(actions)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
if any(sync_dones):
|
if any(sync_terminateds) or any(sync_truncateds):
|
||||||
assert "final_observation" in async_infos
|
assert "final_observation" in async_infos
|
||||||
assert "_final_observation" in async_infos
|
assert "_final_observation" in async_infos
|
||||||
assert "final_observation" in sync_infos
|
assert "final_observation" in sync_infos
|
||||||
@@ -43,7 +43,8 @@ def test_vector_env_equal(shared_memory):
|
|||||||
|
|
||||||
assert np.all(async_observations == sync_observations)
|
assert np.all(async_observations == sync_observations)
|
||||||
assert np.all(async_rewards == sync_rewards)
|
assert np.all(async_rewards == sync_rewards)
|
||||||
assert np.all(async_dones == sync_dones)
|
assert np.all(async_terminateds == sync_terminateds)
|
||||||
|
assert np.all(async_truncateds == sync_truncateds)
|
||||||
|
|
||||||
async_env.close()
|
async_env.close()
|
||||||
sync_env.close()
|
sync_env.close()
|
||||||
|
@@ -20,16 +20,16 @@ def test_vector_env_info(asynchronous):
|
|||||||
for _ in range(ENV_STEPS):
|
for _ in range(ENV_STEPS):
|
||||||
env.action_space.seed(SEED)
|
env.action_space.seed(SEED)
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
_, _, dones, infos = env.step(action)
|
_, _, terminateds, truncateds, infos = env.step(action)
|
||||||
if any(dones):
|
if any(terminateds) or any(truncateds):
|
||||||
assert len(infos["final_observation"]) == NUM_ENVS
|
assert len(infos["final_observation"]) == NUM_ENVS
|
||||||
assert len(infos["_final_observation"]) == NUM_ENVS
|
assert len(infos["_final_observation"]) == NUM_ENVS
|
||||||
|
|
||||||
assert isinstance(infos["final_observation"], np.ndarray)
|
assert isinstance(infos["final_observation"], np.ndarray)
|
||||||
assert isinstance(infos["_final_observation"], np.ndarray)
|
assert isinstance(infos["_final_observation"], np.ndarray)
|
||||||
|
|
||||||
for i, done in enumerate(dones):
|
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
|
||||||
if done:
|
if terminated or truncated:
|
||||||
assert infos["_final_observation"][i]
|
assert infos["_final_observation"][i]
|
||||||
else:
|
else:
|
||||||
assert not infos["_final_observation"][i]
|
assert not infos["_final_observation"][i]
|
||||||
@@ -44,11 +44,11 @@ def test_vector_env_info_concurrent_termination(concurrent_ends):
|
|||||||
envs = SyncVectorEnv(envs)
|
envs = SyncVectorEnv(envs)
|
||||||
|
|
||||||
for _ in range(ENV_STEPS):
|
for _ in range(ENV_STEPS):
|
||||||
_, _, dones, infos = envs.step(actions)
|
_, _, terminateds, truncateds, infos = envs.step(actions)
|
||||||
if any(dones):
|
if any(terminateds) or any(truncateds):
|
||||||
for i, done in enumerate(dones):
|
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
|
||||||
if i < concurrent_ends:
|
if i < concurrent_ends:
|
||||||
assert done
|
assert terminated or truncated
|
||||||
assert infos["_final_observation"][i]
|
assert infos["_final_observation"][i]
|
||||||
else:
|
else:
|
||||||
assert not infos["_final_observation"][i]
|
assert not infos["_final_observation"][i]
|
||||||
|
@@ -68,8 +68,8 @@ class UnittestSlowEnv(gym.Env):
|
|||||||
def step(self, action):
|
def step(self, action):
|
||||||
time.sleep(action)
|
time.sleep(action)
|
||||||
observation = self.observation_space.sample()
|
observation = self.observation_space.sample()
|
||||||
reward, done = 0.0, False
|
reward, terminated, truncated = 0.0, False, False
|
||||||
return observation, reward, done, {}
|
return observation, reward, terminated, truncated, {}
|
||||||
|
|
||||||
|
|
||||||
class CustomSpace(gym.Space):
|
class CustomSpace(gym.Space):
|
||||||
@@ -103,8 +103,8 @@ class CustomSpaceEnv(gym.Env):
|
|||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
observation = f"step({action:s})"
|
observation = f"step({action:s})"
|
||||||
reward, done = 0.0, False
|
reward, terminated, truncated = 0.0, False, False
|
||||||
return observation, reward, done, {}
|
return observation, reward, terminated, truncated, {}
|
||||||
|
|
||||||
|
|
||||||
def make_env(env_name, seed, **kwargs):
|
def make_env(env_name, seed, **kwargs):
|
||||||
|
@@ -2,7 +2,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gym.spaces import Box, Discrete
|
from gym.spaces import Box, Discrete
|
||||||
from gym.wrappers import AtariPreprocessing
|
from gym.wrappers import AtariPreprocessing, StepAPICompatibility
|
||||||
from tests.testing_env import GenericTestEnv, old_step_fn
|
from tests.testing_env import GenericTestEnv, old_step_fn
|
||||||
|
|
||||||
|
|
||||||
@@ -49,7 +49,7 @@ class AtariTestingEnv(GenericTestEnv):
|
|||||||
(AtariTestingEnv(), (210, 160, 3)),
|
(AtariTestingEnv(), (210, 160, 3)),
|
||||||
(
|
(
|
||||||
AtariPreprocessing(
|
AtariPreprocessing(
|
||||||
AtariTestingEnv(),
|
StepAPICompatibility(AtariTestingEnv(), output_truncation_bool=True),
|
||||||
screen_size=84,
|
screen_size=84,
|
||||||
grayscale_obs=True,
|
grayscale_obs=True,
|
||||||
frame_skip=1,
|
frame_skip=1,
|
||||||
@@ -59,7 +59,7 @@ class AtariTestingEnv(GenericTestEnv):
|
|||||||
),
|
),
|
||||||
(
|
(
|
||||||
AtariPreprocessing(
|
AtariPreprocessing(
|
||||||
AtariTestingEnv(),
|
StepAPICompatibility(AtariTestingEnv(), output_truncation_bool=True),
|
||||||
screen_size=84,
|
screen_size=84,
|
||||||
grayscale_obs=False,
|
grayscale_obs=False,
|
||||||
frame_skip=1,
|
frame_skip=1,
|
||||||
@@ -69,7 +69,7 @@ class AtariTestingEnv(GenericTestEnv):
|
|||||||
),
|
),
|
||||||
(
|
(
|
||||||
AtariPreprocessing(
|
AtariPreprocessing(
|
||||||
AtariTestingEnv(),
|
StepAPICompatibility(AtariTestingEnv(), output_truncation_bool=True),
|
||||||
screen_size=84,
|
screen_size=84,
|
||||||
grayscale_obs=True,
|
grayscale_obs=True,
|
||||||
frame_skip=1,
|
frame_skip=1,
|
||||||
@@ -86,10 +86,14 @@ def test_atari_preprocessing_grayscale(env, obs_shape):
|
|||||||
# It is not possible to test the outputs as we are not using actual observations.
|
# It is not possible to test the outputs as we are not using actual observations.
|
||||||
# todo: update when ale-py is compatible with the ci
|
# todo: update when ale-py is compatible with the ci
|
||||||
|
|
||||||
|
env = StepAPICompatibility(
|
||||||
|
env, output_truncation_bool=True
|
||||||
|
) # using compatibility wrapper since ale-py uses old step API
|
||||||
|
|
||||||
obs, _ = env.reset(seed=0)
|
obs, _ = env.reset(seed=0)
|
||||||
assert obs in env.observation_space
|
assert obs in env.observation_space
|
||||||
|
|
||||||
obs, _, _, _ = env.step(env.action_space.sample())
|
obs, _, _, _, _ = env.step(env.action_space.sample())
|
||||||
assert obs in env.observation_space
|
assert obs in env.observation_space
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
@@ -100,7 +104,7 @@ def test_atari_preprocessing_grayscale(env, obs_shape):
|
|||||||
def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10):
|
def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10):
|
||||||
# arbitrarily chosen number for stepping into env. and ensuring all observations are in the required range
|
# arbitrarily chosen number for stepping into env. and ensuring all observations are in the required range
|
||||||
env = AtariPreprocessing(
|
env = AtariPreprocessing(
|
||||||
AtariTestingEnv(),
|
StepAPICompatibility(AtariTestingEnv(), output_truncation_bool=True),
|
||||||
screen_size=84,
|
screen_size=84,
|
||||||
grayscale_obs=grayscale,
|
grayscale_obs=grayscale,
|
||||||
scale_obs=scaled,
|
scale_obs=scaled,
|
||||||
@@ -113,9 +117,9 @@ def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10):
|
|||||||
max_obs = 1 if scaled else 255
|
max_obs = 1 if scaled else 255
|
||||||
assert np.all(0 <= obs) and np.all(obs <= max_obs)
|
assert np.all(0 <= obs) and np.all(obs <= max_obs)
|
||||||
|
|
||||||
done, step_i = False, 0
|
terminated, truncated, step_i = False, False, 0
|
||||||
while not done and step_i <= max_test_steps:
|
while not (terminated or truncated) and step_i <= max_test_steps:
|
||||||
obs, _, done, _ = env.step(env.action_space.sample())
|
obs, _, terminated, truncated, _ = env.step(env.action_space.sample())
|
||||||
assert np.all(0 <= obs) and np.all(obs <= max_obs)
|
assert np.all(0 <= obs) and np.all(obs <= max_obs)
|
||||||
|
|
||||||
step_i += 1
|
step_i += 1
|
||||||
|
@@ -14,7 +14,7 @@ from tests.envs.utils import all_testing_env_specs
|
|||||||
class DummyResetEnv(gym.Env):
|
class DummyResetEnv(gym.Env):
|
||||||
"""A dummy environment which returns ascending numbers starting at `0` when :meth:`self.step()` is called.
|
"""A dummy environment which returns ascending numbers starting at `0` when :meth:`self.step()` is called.
|
||||||
|
|
||||||
After the second call to :meth:`self.step()` done is true.
|
After the second call to :meth:`self.step()` terminated is true.
|
||||||
Info dicts are also returned containing the same number returned as an observation, accessible via the key "count".
|
Info dicts are also returned containing the same number returned as an observation, accessible via the key "count".
|
||||||
This environment is provided for the purpose of testing the autoreset wrapper.
|
This environment is provided for the purpose of testing the autoreset wrapper.
|
||||||
"""
|
"""
|
||||||
@@ -30,12 +30,13 @@ class DummyResetEnv(gym.Env):
|
|||||||
self.count = 0
|
self.count = 0
|
||||||
|
|
||||||
def step(self, action: int):
|
def step(self, action: int):
|
||||||
"""Steps the DummyEnv with the incremented step, reward and done `if self.count > 1` and updated info."""
|
"""Steps the DummyEnv with the incremented step, reward and terminated `if self.count > 1` and updated info."""
|
||||||
self.count += 1
|
self.count += 1
|
||||||
return (
|
return (
|
||||||
np.array([self.count]), # Obs
|
np.array([self.count]), # Obs
|
||||||
self.count > 2, # Reward
|
self.count > 2, # Reward
|
||||||
self.count > 2, # Done
|
self.count > 2, # Terminated
|
||||||
|
False, # Truncated
|
||||||
{"count": self.count}, # Info
|
{"count": self.count}, # Info
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -70,9 +71,9 @@ def test_make_autoreset_true(spec):
|
|||||||
env.reset(seed=0)
|
env.reset(seed=0)
|
||||||
env.unwrapped.reset = MagicMock(side_effect=env.unwrapped.reset)
|
env.unwrapped.reset = MagicMock(side_effect=env.unwrapped.reset)
|
||||||
|
|
||||||
done = False
|
terminated, truncated = False, False
|
||||||
while not done:
|
while not (terminated or truncated):
|
||||||
obs, reward, done, info = env.step(env.action_space.sample())
|
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
|
||||||
|
|
||||||
assert env.unwrapped.reset.called
|
assert env.unwrapped.reset.called
|
||||||
env.close()
|
env.close()
|
||||||
@@ -109,33 +110,32 @@ def test_autoreset_wrapper_autoreset():
|
|||||||
assert info == {"count": 0}
|
assert info == {"count": 0}
|
||||||
|
|
||||||
action = 0
|
action = 0
|
||||||
obs, reward, done, info = env.step(action)
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
assert obs == np.array([1])
|
assert obs == np.array([1])
|
||||||
assert reward == 0
|
assert reward == 0
|
||||||
assert done is False
|
assert (terminated or truncated) is False
|
||||||
assert info == {"count": 1}
|
assert info == {"count": 1}
|
||||||
|
|
||||||
obs, reward, done, info = env.step(action)
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
assert obs == np.array([2])
|
assert obs == np.array([2])
|
||||||
assert done is False
|
assert (terminated or truncated) is False
|
||||||
assert reward == 0
|
assert reward == 0
|
||||||
assert info == {"count": 2}
|
assert info == {"count": 2}
|
||||||
|
|
||||||
obs, reward, done, info = env.step(action)
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
assert obs == np.array([0])
|
assert obs == np.array([0])
|
||||||
assert done is True
|
assert (terminated or truncated) is True
|
||||||
assert reward == 1
|
assert reward == 1
|
||||||
assert info == {
|
assert info == {
|
||||||
"count": 0,
|
"count": 0,
|
||||||
"final_observation": np.array([3]),
|
"final_observation": np.array([3]),
|
||||||
"final_info": {"count": 3},
|
"final_info": {"count": 3},
|
||||||
"TimeLimit.truncated": False,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
obs, reward, done, info = env.step(action)
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
assert obs == np.array([1])
|
assert obs == np.array([1])
|
||||||
assert reward == 0
|
assert reward == 0
|
||||||
assert done is False
|
assert (terminated or truncated) is False
|
||||||
assert info == {"count": 1}
|
assert info == {"count": 1}
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
@@ -18,10 +18,11 @@ def test_clip_action():
|
|||||||
|
|
||||||
actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]]
|
actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]]
|
||||||
for action in actions:
|
for action in actions:
|
||||||
obs1, r1, d1, _ = env.step(
|
obs1, r1, ter1, trunc1, _ = env.step(
|
||||||
np.clip(action, env.action_space.low, env.action_space.high)
|
np.clip(action, env.action_space.low, env.action_space.high)
|
||||||
)
|
)
|
||||||
obs2, r2, d2, _ = wrapped_env.step(action)
|
obs2, r2, ter2, trunc2, _ = wrapped_env.step(action)
|
||||||
assert np.allclose(r1, r2)
|
assert np.allclose(r1, r2)
|
||||||
assert np.allclose(obs1, obs2)
|
assert np.allclose(obs1, obs2)
|
||||||
assert d1 == d2
|
assert ter1 == ter2
|
||||||
|
assert trunc1 == trunc2
|
||||||
|
@@ -39,13 +39,14 @@ def test_frame_stack(env_id, num_stack, lz4_compress):
|
|||||||
|
|
||||||
for _ in range(num_stack**2):
|
for _ in range(num_stack**2):
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
dup_obs, _, dup_done, _ = dup.step(action)
|
dup_obs, _, dup_terminated, dup_truncated, _ = dup.step(action)
|
||||||
obs, _, done, _ = env.step(action)
|
obs, _, terminated, truncated, _ = env.step(action)
|
||||||
|
|
||||||
assert dup_done == done
|
assert dup_terminated == terminated
|
||||||
|
assert dup_truncated == truncated
|
||||||
assert np.allclose(obs[-1], dup_obs)
|
assert np.allclose(obs[-1], dup_obs)
|
||||||
|
|
||||||
if done:
|
if terminated or truncated:
|
||||||
break
|
break
|
||||||
|
|
||||||
assert len(obs) == num_stack
|
assert len(obs) == num_stack
|
||||||
|
@@ -15,8 +15,8 @@ def test_human_rendering():
|
|||||||
env.reset()
|
env.reset()
|
||||||
|
|
||||||
for _ in range(75):
|
for _ in range(75):
|
||||||
_, _, done, _ = env.step(env.action_space.sample())
|
_, _, terminated, truncated, _ = env.step(env.action_space.sample())
|
||||||
if done:
|
if terminated or truncated:
|
||||||
env.reset()
|
env.reset()
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
@@ -21,7 +21,13 @@ class DummyRewardEnv(gym.Env):
|
|||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
self.t += 1
|
self.t += 1
|
||||||
return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}
|
return (
|
||||||
|
np.array([self.t]),
|
||||||
|
self.t,
|
||||||
|
self.t == len(self.returned_rewards),
|
||||||
|
False,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
@@ -77,7 +83,7 @@ def test_normalize_observation_vector_env():
|
|||||||
env_fns = [make_env(0), make_env(1)]
|
env_fns = [make_env(0), make_env(1)]
|
||||||
envs = gym.vector.SyncVectorEnv(env_fns)
|
envs = gym.vector.SyncVectorEnv(env_fns)
|
||||||
envs.reset()
|
envs.reset()
|
||||||
obs, reward, _, _ = envs.step(envs.action_space.sample())
|
obs, reward, _, _, _ = envs.step(envs.action_space.sample())
|
||||||
np.testing.assert_almost_equal(obs, np.array([[1], [2]]), decimal=4)
|
np.testing.assert_almost_equal(obs, np.array([[1], [2]]), decimal=4)
|
||||||
np.testing.assert_almost_equal(reward, np.array([1, 2]), decimal=4)
|
np.testing.assert_almost_equal(reward, np.array([1, 2]), decimal=4)
|
||||||
|
|
||||||
@@ -90,7 +96,7 @@ def test_normalize_observation_vector_env():
|
|||||||
np.mean([0.5]), # the mean of first observations [[0, 1]]
|
np.mean([0.5]), # the mean of first observations [[0, 1]]
|
||||||
decimal=4,
|
decimal=4,
|
||||||
)
|
)
|
||||||
obs, reward, _, _ = envs.step(envs.action_space.sample())
|
obs, reward, _, _, _ = envs.step(envs.action_space.sample())
|
||||||
assert_almost_equal(
|
assert_almost_equal(
|
||||||
envs.obs_rms.mean,
|
envs.obs_rms.mean,
|
||||||
np.mean([1.0]), # the mean of first and second observations [[0, 1], [1, 2]]
|
np.mean([1.0]), # the mean of first and second observations [[0, 1], [1, 2]]
|
||||||
@@ -103,13 +109,13 @@ def test_normalize_return_vector_env():
|
|||||||
envs = gym.vector.SyncVectorEnv(env_fns)
|
envs = gym.vector.SyncVectorEnv(env_fns)
|
||||||
envs = NormalizeReward(envs)
|
envs = NormalizeReward(envs)
|
||||||
obs = envs.reset()
|
obs = envs.reset()
|
||||||
obs, reward, _, _ = envs.step(envs.action_space.sample())
|
obs, reward, _, _, _ = envs.step(envs.action_space.sample())
|
||||||
assert_almost_equal(
|
assert_almost_equal(
|
||||||
envs.return_rms.mean,
|
envs.return_rms.mean,
|
||||||
np.mean([1.5]), # the mean of first returns [[1, 2]]
|
np.mean([1.5]), # the mean of first returns [[1, 2]]
|
||||||
decimal=4,
|
decimal=4,
|
||||||
)
|
)
|
||||||
obs, reward, _, _ = envs.step(envs.action_space.sample())
|
obs, reward, _, _, _ = envs.step(envs.action_space.sample())
|
||||||
assert_almost_equal(
|
assert_almost_equal(
|
||||||
envs.return_rms.mean,
|
envs.return_rms.mean,
|
||||||
np.mean(
|
np.mean(
|
||||||
|
@@ -18,8 +18,8 @@ def test_record_episode_statistics(env_id, deque_size):
|
|||||||
assert env.episode_returns[0] == 0.0
|
assert env.episode_returns[0] == 0.0
|
||||||
assert env.episode_lengths[0] == 0
|
assert env.episode_lengths[0] == 0
|
||||||
for t in range(env.spec.max_episode_steps):
|
for t in range(env.spec.max_episode_steps):
|
||||||
_, _, done, info = env.step(env.action_space.sample())
|
_, _, terminated, truncated, info = env.step(env.action_space.sample())
|
||||||
if done:
|
if terminated or truncated:
|
||||||
assert "episode" in info
|
assert "episode" in info
|
||||||
assert all([item in info["episode"] for item in ["r", "l", "t"]])
|
assert all([item in info["episode"] for item in ["r", "l", "t"]])
|
||||||
break
|
break
|
||||||
@@ -55,11 +55,11 @@ def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous):
|
|||||||
)
|
)
|
||||||
envs.reset()
|
envs.reset()
|
||||||
for _ in range(max_episode_step + 1):
|
for _ in range(max_episode_step + 1):
|
||||||
_, _, dones, infos = envs.step(envs.action_space.sample())
|
_, _, terminateds, truncateds, infos = envs.step(envs.action_space.sample())
|
||||||
if any(dones):
|
if any(terminateds) or any(truncateds):
|
||||||
assert "episode" in infos
|
assert "episode" in infos
|
||||||
assert "_episode" in infos
|
assert "_episode" in infos
|
||||||
assert all(infos["_episode"] == dones)
|
assert all(infos["_episode"] == np.bitwise_or(terminateds, truncateds))
|
||||||
assert all([item in infos["episode"] for item in ["r", "l", "t"]])
|
assert all([item in infos["episode"] for item in ["r", "l", "t"]])
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
@@ -11,8 +11,8 @@ def test_record_video_using_default_trigger():
|
|||||||
env.reset()
|
env.reset()
|
||||||
for _ in range(199):
|
for _ in range(199):
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
_, _, done, _ = env.step(action)
|
_, _, terminated, truncated, _ = env.step(action)
|
||||||
if done:
|
if terminated or truncated:
|
||||||
env.reset()
|
env.reset()
|
||||||
env.close()
|
env.close()
|
||||||
assert os.path.isdir("videos")
|
assert os.path.isdir("videos")
|
||||||
@@ -42,8 +42,8 @@ def test_record_video_step_trigger():
|
|||||||
env.reset()
|
env.reset()
|
||||||
for _ in range(199):
|
for _ in range(199):
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
_, _, done, _ = env.step(action)
|
_, _, terminated, truncated, _ = env.step(action)
|
||||||
if done:
|
if terminated or truncated:
|
||||||
env.reset()
|
env.reset()
|
||||||
env.close()
|
env.close()
|
||||||
assert os.path.isdir("videos")
|
assert os.path.isdir("videos")
|
||||||
@@ -72,7 +72,7 @@ def test_record_video_within_vector():
|
|||||||
envs = gym.wrappers.RecordEpisodeStatistics(envs)
|
envs = gym.wrappers.RecordEpisodeStatistics(envs)
|
||||||
envs.reset()
|
envs.reset()
|
||||||
for i in range(199):
|
for i in range(199):
|
||||||
_, _, _, infos = envs.step(envs.action_space.sample())
|
_, _, _, _, infos = envs.step(envs.action_space.sample())
|
||||||
|
|
||||||
# break when every env is done
|
# break when every env is done
|
||||||
if "episode" in infos and all(infos["_episode"]):
|
if "episode" in infos and all(infos["_episode"]):
|
||||||
|
@@ -22,10 +22,10 @@ def test_rescale_action():
|
|||||||
wrapped_obs, wrapped_obs_info = wrapped_env.reset(seed=seed)
|
wrapped_obs, wrapped_obs_info = wrapped_env.reset(seed=seed)
|
||||||
assert np.allclose(obs, wrapped_obs)
|
assert np.allclose(obs, wrapped_obs)
|
||||||
|
|
||||||
obs, reward, _, _ = env.step([1.5])
|
obs, reward, _, _, _ = env.step([1.5])
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
wrapped_env.step([1.5])
|
wrapped_env.step([1.5])
|
||||||
wrapped_obs, wrapped_reward, _, _ = wrapped_env.step([0.75])
|
wrapped_obs, wrapped_reward, _, _, _ = wrapped_env.step([0.75])
|
||||||
|
|
||||||
assert np.allclose(obs, wrapped_obs)
|
assert np.allclose(obs, wrapped_obs)
|
||||||
assert np.allclose(reward, wrapped_reward)
|
assert np.allclose(reward, wrapped_reward)
|
||||||
|
@@ -33,8 +33,12 @@ class NewStepEnv(gym.Env):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
|
@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
|
||||||
def test_step_compatibility_to_new_api(env):
|
@pytest.mark.parametrize("output_truncation_bool", [None, True])
|
||||||
env = StepAPICompatibility(env(), True)
|
def test_step_compatibility_to_new_api(env, output_truncation_bool):
|
||||||
|
if output_truncation_bool is None:
|
||||||
|
env = StepAPICompatibility(env())
|
||||||
|
else:
|
||||||
|
env = StepAPICompatibility(env(), output_truncation_bool)
|
||||||
step_returns = env.step(0)
|
step_returns = env.step(0)
|
||||||
_, _, terminated, truncated, _ = step_returns
|
_, _, terminated, truncated, _ = step_returns
|
||||||
assert isinstance(terminated, bool)
|
assert isinstance(terminated, bool)
|
||||||
@@ -42,31 +46,30 @@ def test_step_compatibility_to_new_api(env):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
|
@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
|
||||||
@pytest.mark.parametrize("new_step_api", [None, False])
|
def test_step_compatibility_to_old_api(env):
|
||||||
def test_step_compatibility_to_old_api(env, new_step_api):
|
env = StepAPICompatibility(env(), False)
|
||||||
if new_step_api is None:
|
|
||||||
env = StepAPICompatibility(env()) # default behavior is to retain old API
|
|
||||||
else:
|
|
||||||
env = StepAPICompatibility(env(), new_step_api)
|
|
||||||
step_returns = env.step(0)
|
step_returns = env.step(0)
|
||||||
assert len(step_returns) == 4
|
assert len(step_returns) == 4
|
||||||
_, _, done, _ = step_returns
|
_, _, done, _ = step_returns
|
||||||
assert isinstance(done, bool)
|
assert isinstance(done, bool)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("new_step_api", [None, True, False])
|
@pytest.mark.parametrize("apply_step_compatibility", [None, True, False])
|
||||||
def test_step_compatibility_in_make(new_step_api):
|
def test_step_compatibility_in_make(apply_step_compatibility):
|
||||||
if new_step_api is None:
|
gym.register("OldStepEnv-v0", entry_point=OldStepEnv)
|
||||||
with pytest.warns(
|
|
||||||
DeprecationWarning, match="Initializing environment in old step API"
|
if apply_step_compatibility is not None:
|
||||||
):
|
env = gym.make(
|
||||||
env = gym.make("CartPole-v1")
|
"OldStepEnv-v0",
|
||||||
else:
|
apply_step_compatibility=apply_step_compatibility,
|
||||||
env = gym.make("CartPole-v1", new_step_api=new_step_api)
|
disable_env_checker=True,
|
||||||
|
)
|
||||||
|
elif apply_step_compatibility is None:
|
||||||
|
env = gym.make("OldStepEnv-v0", disable_env_checker=True)
|
||||||
|
|
||||||
env.reset()
|
env.reset()
|
||||||
step_returns = env.step(0)
|
step_returns = env.step(0)
|
||||||
if new_step_api:
|
if apply_step_compatibility:
|
||||||
assert len(step_returns) == 5
|
assert len(step_returns) == 5
|
||||||
_, _, terminated, truncated, _ = step_returns
|
_, _, terminated, truncated, _ = step_returns
|
||||||
assert isinstance(terminated, bool)
|
assert isinstance(terminated, bool)
|
||||||
|
@@ -20,12 +20,12 @@ def test_time_aware_observation(env_id):
|
|||||||
assert wrapped_obs[-1] == 0.0
|
assert wrapped_obs[-1] == 0.0
|
||||||
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
||||||
|
|
||||||
wrapped_obs, _, _, _ = wrapped_env.step(env.action_space.sample())
|
wrapped_obs, _, _, _, _ = wrapped_env.step(env.action_space.sample())
|
||||||
assert wrapped_env.t == 1.0
|
assert wrapped_env.t == 1.0
|
||||||
assert wrapped_obs[-1] == 1.0
|
assert wrapped_obs[-1] == 1.0
|
||||||
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
||||||
|
|
||||||
wrapped_obs, _, _, _ = wrapped_env.step(env.action_space.sample())
|
wrapped_obs, _, _, _, _ = wrapped_env.step(env.action_space.sample())
|
||||||
assert wrapped_env.t == 2.0
|
assert wrapped_env.t == 2.0
|
||||||
assert wrapped_obs[-1] == 2.0
|
assert wrapped_obs[-1] == 2.0
|
||||||
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
||||||
|
@@ -22,31 +22,28 @@ def test_time_limit_wrapper(double_wrap):
|
|||||||
max_episode_length = 20
|
max_episode_length = 20
|
||||||
env = TimeLimit(env, max_episode_length)
|
env = TimeLimit(env, max_episode_length)
|
||||||
if double_wrap:
|
if double_wrap:
|
||||||
# TimeLimit wrapper should not overwrite
|
|
||||||
# the TimeLimit.truncated key
|
|
||||||
# if it was already set
|
|
||||||
env = TimeLimit(env, max_episode_length)
|
env = TimeLimit(env, max_episode_length)
|
||||||
env.reset()
|
env.reset()
|
||||||
done = False
|
terminated, truncated = False, False
|
||||||
n_steps = 0
|
n_steps = 0
|
||||||
info = {}
|
info = {}
|
||||||
while not done:
|
while not (terminated or truncated):
|
||||||
n_steps += 1
|
n_steps += 1
|
||||||
_, _, done, info = env.step(env.action_space.sample())
|
_, _, terminated, truncated, info = env.step(env.action_space.sample())
|
||||||
|
|
||||||
assert n_steps == max_episode_length
|
assert n_steps == max_episode_length
|
||||||
assert "TimeLimit.truncated" in info
|
assert truncated
|
||||||
assert info["TimeLimit.truncated"] is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("double_wrap", [False, True])
|
@pytest.mark.parametrize("double_wrap", [False, True])
|
||||||
def test_termination_on_last_step(double_wrap):
|
def test_termination_on_last_step(double_wrap):
|
||||||
# Special case: termination at the last timestep
|
# Special case: termination at the last timestep
|
||||||
# but not due to timeout
|
# Truncation due to timeout also happens at the same step
|
||||||
|
|
||||||
env = PendulumEnv()
|
env = PendulumEnv()
|
||||||
|
|
||||||
def patched_step(_action):
|
def patched_step(_action):
|
||||||
return env.observation_space.sample(), 0.0, True, {}
|
return env.observation_space.sample(), 0.0, True, False, {}
|
||||||
|
|
||||||
env.step = patched_step
|
env.step = patched_step
|
||||||
|
|
||||||
@@ -55,7 +52,6 @@ def test_termination_on_last_step(double_wrap):
|
|||||||
if double_wrap:
|
if double_wrap:
|
||||||
env = TimeLimit(env, max_episode_length)
|
env = TimeLimit(env, max_episode_length)
|
||||||
env.reset()
|
env.reset()
|
||||||
_, _, done, info = env.step(env.action_space.sample())
|
_, _, terminated, truncated, _ = env.step(env.action_space.sample())
|
||||||
assert done is True
|
assert terminated is True
|
||||||
assert "TimeLimit.truncated" in info
|
assert truncated is True
|
||||||
assert info["TimeLimit.truncated"] is False
|
|
||||||
|
@@ -21,8 +21,15 @@ def test_transform_observation(env_id):
|
|||||||
assert isinstance(wrapped_obs_info, dict)
|
assert isinstance(wrapped_obs_info, dict)
|
||||||
|
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
obs, reward, done, _ = env.step(action)
|
obs, reward, terminated, truncated, _ = env.step(action)
|
||||||
wrapped_obs, wrapped_reward, wrapped_done, _ = wrapped_env.step(action)
|
(
|
||||||
|
wrapped_obs,
|
||||||
|
wrapped_reward,
|
||||||
|
wrapped_terminated,
|
||||||
|
wrapped_truncated,
|
||||||
|
_,
|
||||||
|
) = wrapped_env.step(action)
|
||||||
assert np.allclose(wrapped_obs, affine_transform(obs))
|
assert np.allclose(wrapped_obs, affine_transform(obs))
|
||||||
assert np.allclose(wrapped_reward, reward)
|
assert np.allclose(wrapped_reward, reward)
|
||||||
assert wrapped_done == done
|
assert wrapped_terminated == terminated
|
||||||
|
assert wrapped_truncated == truncated
|
||||||
|
@@ -19,8 +19,8 @@ def test_transform_reward(env_id):
|
|||||||
env.reset(seed=0)
|
env.reset(seed=0)
|
||||||
wrapped_env.reset(seed=0)
|
wrapped_env.reset(seed=0)
|
||||||
|
|
||||||
_, reward, _, _ = env.step(action)
|
_, reward, _, _, _ = env.step(action)
|
||||||
_, wrapped_reward, _, _ = wrapped_env.step(action)
|
_, wrapped_reward, _, _, _ = wrapped_env.step(action)
|
||||||
|
|
||||||
assert wrapped_reward == scale * reward
|
assert wrapped_reward == scale * reward
|
||||||
del env, wrapped_env
|
del env, wrapped_env
|
||||||
@@ -37,8 +37,8 @@ def test_transform_reward(env_id):
|
|||||||
env.reset(seed=0)
|
env.reset(seed=0)
|
||||||
wrapped_env.reset(seed=0)
|
wrapped_env.reset(seed=0)
|
||||||
|
|
||||||
_, reward, _, _ = env.step(action)
|
_, reward, _, _, _ = env.step(action)
|
||||||
_, wrapped_reward, _, _ = wrapped_env.step(action)
|
_, wrapped_reward, _, _, _ = wrapped_env.step(action)
|
||||||
|
|
||||||
assert abs(wrapped_reward) < abs(reward)
|
assert abs(wrapped_reward) < abs(reward)
|
||||||
assert wrapped_reward == -0.0005 or wrapped_reward == 0.0002
|
assert wrapped_reward == -0.0005 or wrapped_reward == 0.0002
|
||||||
@@ -55,8 +55,8 @@ def test_transform_reward(env_id):
|
|||||||
|
|
||||||
for _ in range(1000):
|
for _ in range(1000):
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
_, wrapped_reward, done, _ = wrapped_env.step(action)
|
_, wrapped_reward, terminated, truncated, _ = wrapped_env.step(action)
|
||||||
assert wrapped_reward in [-1.0, 0.0, 1.0]
|
assert wrapped_reward in [-1.0, 0.0, 1.0]
|
||||||
if done:
|
if terminated or truncated:
|
||||||
break
|
break
|
||||||
del env, wrapped_env
|
del env, wrapped_env
|
||||||
|
@@ -29,9 +29,9 @@ def test_info_to_list():
|
|||||||
|
|
||||||
for _ in range(ENV_STEPS):
|
for _ in range(ENV_STEPS):
|
||||||
action = wrapped_env.action_space.sample()
|
action = wrapped_env.action_space.sample()
|
||||||
_, _, dones, list_info = wrapped_env.step(action)
|
_, _, terminateds, truncateds, list_info = wrapped_env.step(action)
|
||||||
for i, done in enumerate(dones):
|
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
|
||||||
if done:
|
if terminated or truncated:
|
||||||
assert "final_observation" in list_info[i]
|
assert "final_observation" in list_info[i]
|
||||||
else:
|
else:
|
||||||
assert "final_observation" not in list_info[i]
|
assert "final_observation" not in list_info[i]
|
||||||
@@ -47,9 +47,9 @@ def test_info_to_list_statistics():
|
|||||||
|
|
||||||
for _ in range(ENV_STEPS):
|
for _ in range(ENV_STEPS):
|
||||||
action = wrapped_env.action_space.sample()
|
action = wrapped_env.action_space.sample()
|
||||||
_, _, dones, list_info = wrapped_env.step(action)
|
_, _, terminateds, truncateds, list_info = wrapped_env.step(action)
|
||||||
for i, done in enumerate(dones):
|
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
|
||||||
if done:
|
if terminated or truncated:
|
||||||
assert "episode" in list_info[i]
|
assert "episode" in list_info[i]
|
||||||
for stats in ["r", "l", "t"]:
|
for stats in ["r", "l", "t"]:
|
||||||
assert stats in list_info[i]["episode"]
|
assert stats in list_info[i]["episode"]
|
||||||
|
Reference in New Issue
Block a user