Removing return_info argument to env.reset() and deprecated env.seed() function (reset now always returns info) (#2962)

* removed return_info, made info dict mandatory in reset

* tentatively removed deprecated seed api for environments

* added more info type checks to wrapper tests

* formatting/style compliance

* addressed some comments

* polish to address review

* fixed tests after merge, and added a test of the return_info deprecation assertion if found in reset signature

* some organization of env_checker tests, reverted a probable merge error

* added deprecation check for seed function in env

* updated docstring

* removed debug prints, tweaked test_check_seed_deprecation

* changed return_info deprecation check from assertion to warning

* fixes to vector envs, now should be correctly structured

* added some explanation and type hints for mockup deprecated return info reset function

* re-removed seed function from vector envs

* added explanation to _reset_return_info_type and changed the return statement
John Balis
2022-08-23 11:09:54 -04:00
committed by GitHub
parent 1f864789fd
commit 3a8daafce1
56 changed files with 327 additions and 639 deletions
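
For calling code, the change described in the commit message amounts to the migration sketched below. This is an illustration under stated assumptions, not code from the commit: `ToyEnv` is a hypothetical stand-in class (not part of Gym) that only mimics the new `reset` signature, so the example runs without Gym installed.

```python
import random
from typing import Optional, Tuple


class ToyEnv:
    """Hypothetical stand-in that follows the post-commit reset contract."""

    def __init__(self) -> None:
        self._rng = random.Random()
        self.state = 0.0

    def reset(
        self, *, seed: Optional[int] = None, options: Optional[dict] = None
    ) -> Tuple[float, dict]:
        # Seeding goes through reset(seed=...); there is no env.seed() any more.
        if seed is not None:
            self._rng = random.Random(seed)
        self.state = self._rng.uniform(-0.05, 0.05)
        # reset always returns (observation, info); info may be an empty dict.
        return self.state, {}


env = ToyEnv()

# Before this commit: env.seed(42); obs = env.reset()
#                 or: obs, info = env.reset(seed=42, return_info=True)
# After this commit:
obs, info = env.reset(seed=42)
obs_again, _ = env.reset(seed=42)  # same seed, same initial observation
```

With a keyword-only signature like the one above, a leftover `return_info=True` at a call site fails with a `TypeError` rather than being silently ignored, which makes stale calls easy to find.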


@@ -23,14 +23,14 @@ The Gym API's API models environments as simple Python `env` classes. Creating e
```python ```python
import gym import gym
env = gym.make("CartPole-v1") env = gym.make("CartPole-v1")
observation, info = env.reset(seed=42, return_info=True) observation, info = env.reset(seed=42)
for _ in range(1000): for _ in range(1000):
action = env.action_space.sample() action = env.action_space.sample()
observation, reward, done, info = env.step(action) observation, reward, done, info = env.step(action)
if done: if done:
observation, info = env.reset(return_info=True) observation, info = env.reset()
env.close() env.close()
``` ```


@@ -41,11 +41,10 @@ class Env(Generic[ObsType, ActType]):
The main API methods that users of this class need to know are: The main API methods that users of this class need to know are:
- :meth:`step` - Takes a step in the environment using an action returning the next observation, reward, - :meth:`step` - Takes a step in the environment using an action returning the next observation, reward,
if the environment terminated and more information. if the environment terminated and observation information.
- :meth:`reset` - Resets the environment to an initial state, returning the initial observation. - :meth:`reset` - Resets the environment to an initial state, returning the initial observation and observation information.
- :meth:`render` - Renders the environment observation with modes depending on the output - :meth:`render` - Renders the environment observation with modes depending on the output
- :meth:`close` - Closes the environment, important for rendering where pygame is imported - :meth:`close` - Closes the environment, important for rendering where pygame is imported
- :meth:`seed` - Seeds the environment's random number generator, :deprecated: in favor of `Env.reset(seed=seed)`.
And set the following attributes: And set the following attributes:
@@ -124,9 +123,8 @@ class Env(Generic[ObsType, ActType]):
self, self,
*, *,
seed: Optional[int] = None, seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
) -> Union[ObsType, Tuple[ObsType, dict]]: ) -> Tuple[ObsType, dict]:
"""Resets the environment to an initial state and returns the initial observation. """Resets the environment to an initial state and returns the initial observation.
This method can reset the environment's random number generator(s) if ``seed`` is an integer or This method can reset the environment's random number generator(s) if ``seed`` is an integer or
@@ -143,8 +141,6 @@ class Env(Generic[ObsType, ActType]):
If you pass an integer, the PRNG will be reset even if it already exists. If you pass an integer, the PRNG will be reset even if it already exists.
Usually, you want to pass an integer *right after the environment has been initialized and then never again*. Usually, you want to pass an integer *right after the environment has been initialized and then never again*.
Please refer to the minimal example above to see this paradigm in action. Please refer to the minimal example above to see this paradigm in action.
return_info (bool): If true, return additional information along with initial observation.
This info should be analogous to the info returned in :meth:`step`
options (optional dict): Additional information to specify how the environment is reset (optional, options (optional dict): Additional information to specify how the environment is reset (optional,
depending on the specific environment) depending on the specific environment)
@@ -152,8 +148,7 @@ class Env(Generic[ObsType, ActType]):
Returns: Returns:
observation (object): Observation of the initial state. This will be an element of :attr:`observation_space` observation (object): Observation of the initial state. This will be an element of :attr:`observation_space`
(typically a numpy array) and is analogous to the observation returned by :meth:`step`. (typically a numpy array) and is analogous to the observation returned by :meth:`step`.
info (optional dictionary): This will *only* be returned if ``return_info=True`` is passed. info (dictionary): This dictionary contains auxiliary information complementing ``observation``. It should be analogous to
It contains auxiliary information complementing ``observation``. This dictionary should be analogous to
the ``info`` returned by :meth:`step`. the ``info`` returned by :meth:`step`.
""" """
# Initialize the RNG if the seed is manually passed # Initialize the RNG if the seed is manually passed
@@ -193,33 +188,6 @@ class Env(Generic[ObsType, ActType]):
""" """
pass pass
def seed(self, seed=None):
""":deprecated: function that sets the seed for the environment's random number generator(s).
Use `env.reset(seed=seed)` as the new API for setting the seed of the environment.
Note:
Some environments use multiple pseudorandom number generators.
We want to capture all such seeds used in order to ensure that
there aren't accidental correlations between multiple generators.
Args:
seed(Optional int): The seed value for the random number generator
Returns:
seeds (List[int]): Returns the list of seeds used in this environment's random
number generators. The first value in the list should be the
"main" seed, or the value which a reproducer should pass to
'seed'. Often, the main seed equals the provided 'seed', but
this won't be true `if seed=None`, for example.
"""
deprecation(
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
"Please use `env.reset(seed=seed)` instead."
)
self._np_random, seed = seeding.np_random(seed)
return [seed]
@property @property
def unwrapped(self) -> "Env": def unwrapped(self) -> "Env":
"""Returns the base non-wrapped environment. """Returns the base non-wrapped environment.
@@ -370,7 +338,7 @@ class Wrapper(Env[ObsType, ActType]):
return step_api_compatibility(self.env.step(action), self.new_step_api) return step_api_compatibility(self.env.step(action), self.new_step_api)
def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]: def reset(self, **kwargs) -> Tuple[ObsType, dict]:
"""Resets the environment with kwargs.""" """Resets the environment with kwargs."""
return self.env.reset(**kwargs) return self.env.reset(**kwargs)
@@ -384,10 +352,6 @@ class Wrapper(Env[ObsType, ActType]):
"""Closes the environment.""" """Closes the environment."""
return self.env.close() return self.env.close()
def seed(self, seed=None):
"""Seeds the environment."""
return self.env.seed(seed)
def __str__(self): def __str__(self):
"""Returns the wrapper name and the unwrapped environment string.""" """Returns the wrapper name and the unwrapped environment string."""
return f"<{type(self).__name__}{self.env}>" return f"<{type(self).__name__}{self.env}>"
@@ -432,11 +396,8 @@ class ObservationWrapper(Wrapper):
def reset(self, **kwargs): def reset(self, **kwargs):
"""Resets the environment, returning a modified observation using :meth:`self.observation`.""" """Resets the environment, returning a modified observation using :meth:`self.observation`."""
if kwargs.get("return_info", False): obs, info = self.env.reset(**kwargs)
obs, info = self.env.reset(**kwargs) return self.observation(obs), info
return self.observation(obs), info
else:
return self.observation(self.env.reset(**kwargs))
def step(self, action): def step(self, action):
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`.""" """Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""


@@ -428,7 +428,6 @@ class BipedalWalker(gym.Env, EzPickle):
self, self,
*, *,
seed: Optional[int] = None, seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
super().reset(seed=seed) super().reset(seed=seed)
@@ -514,10 +513,7 @@ class BipedalWalker(gym.Env, EzPickle):
self.lidar = [LidarCallback() for _ in range(10)] self.lidar = [LidarCallback() for _ in range(10)]
self.renderer.reset() self.renderer.reset()
if not return_info: return self.step(np.array([0, 0, 0, 0]))[0], {}
return self.step(np.array([0, 0, 0, 0]))[0]
else:
return self.step(np.array([0, 0, 0, 0]))[0], {}
def step(self, action: np.ndarray): def step(self, action: np.ndarray):
assert self.hull is not None assert self.hull is not None


@@ -483,7 +483,6 @@ class CarRacing(gym.Env, EzPickle):
self, self,
*, *,
seed: Optional[int] = None, seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
super().reset(seed=seed) super().reset(seed=seed)
@@ -519,10 +518,7 @@ class CarRacing(gym.Env, EzPickle):
self.car = Car(self.world, *self.track[0][1:4]) self.car = Car(self.world, *self.track[0][1:4])
self.renderer.reset() self.renderer.reset()
if not return_info: return self.step(None)[0], {}
return self.step(None)[0]
else:
return self.step(None)[0], {}
def step(self, action: Union[np.ndarray, int]): def step(self, action: Union[np.ndarray, int]):
assert self.car is not None assert self.car is not None


@@ -305,7 +305,6 @@ class LunarLander(gym.Env, EzPickle):
self, self,
*, *,
seed: Optional[int] = None, seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
super().reset(seed=seed) super().reset(seed=seed)
@@ -413,10 +412,7 @@ class LunarLander(gym.Env, EzPickle):
self.drawlist = [self.lander] + self.legs self.drawlist = [self.lander] + self.legs
self.renderer.reset() self.renderer.reset()
if not return_info: return self.step(np.array([0, 0]) if self.continuous else 0)[0], {}
return self.step(np.array([0, 0]) if self.continuous else 0)[0]
else:
return self.step(np.array([0, 0]) if self.continuous else 0)[0], {}
def _create_particle(self, mass, x, y, ttl): def _create_particle(self, mass, x, y, ttl):
p = self.world.CreateDynamicBody( p = self.world.CreateDynamicBody(
@@ -774,7 +770,7 @@ def demo_heuristic_lander(env, seed=None, render=False):
total_reward = 0 total_reward = 0
steps = 0 steps = 0
s = env.reset(seed=seed) s, info = env.reset(seed=seed)
while True: while True:
a = heuristic(env, s) a = heuristic(env, s)
s, r, terminated, truncated, info = step_api_compatibility(env.step(a), True) s, r, terminated, truncated, info = step_api_compatibility(env.step(a), True)


@@ -180,13 +180,7 @@ class AcrobotEnv(core.Env):
self.action_space = spaces.Discrete(3) self.action_space = spaces.Discrete(3)
self.state = None self.state = None
def reset( def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed) super().reset(seed=seed)
# Note that if you use custom reset bounds, it may lead to out-of-bound # Note that if you use custom reset bounds, it may lead to out-of-bound
# state/observations. # state/observations.
@@ -199,10 +193,7 @@ class AcrobotEnv(core.Env):
self.renderer.reset() self.renderer.reset()
self.renderer.render_step() self.renderer.render_step()
if not return_info: return self._get_ob(), {}
return self._get_ob()
else:
return self._get_ob(), {}
def step(self, a): def step(self, a):
s = self.state s = self.state


@@ -192,7 +192,6 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
self, self,
*, *,
seed: Optional[int] = None, seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
super().reset(seed=seed) super().reset(seed=seed)
@@ -205,10 +204,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
self.steps_beyond_terminated = None self.steps_beyond_terminated = None
self.renderer.reset() self.renderer.reset()
self.renderer.render_step() self.renderer.render_step()
if not return_info: return np.array(self.state, dtype=np.float32), {}
return np.array(self.state, dtype=np.float32)
else:
return np.array(self.state, dtype=np.float32), {}
def render(self): def render(self):
return self.renderer.get_renders() return self.renderer.get_renders()


@@ -174,13 +174,7 @@ class Continuous_MountainCarEnv(gym.Env):
self.renderer.render_step() self.renderer.render_step()
return self.state, reward, terminated, False, {} return self.state, reward, terminated, False, {}
def reset( def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed) super().reset(seed=seed)
# Note that if you use custom reset bounds, it may lead to out-of-bound # Note that if you use custom reset bounds, it may lead to out-of-bound
# state/observations. # state/observations.
@@ -188,10 +182,7 @@ class Continuous_MountainCarEnv(gym.Env):
self.state = np.array([self.np_random.uniform(low=low, high=high), 0]) self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
self.renderer.reset() self.renderer.reset()
self.renderer.render_step() self.renderer.render_step()
if not return_info: return np.array(self.state, dtype=np.float32), {}
return np.array(self.state, dtype=np.float32)
else:
return np.array(self.state, dtype=np.float32), {}
def _height(self, xs): def _height(self, xs):
return np.sin(3 * xs) * 0.45 + 0.55 return np.sin(3 * xs) * 0.45 + 0.55


@@ -152,7 +152,6 @@ class MountainCarEnv(gym.Env):
self, self,
*, *,
seed: Optional[int] = None, seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
super().reset(seed=seed) super().reset(seed=seed)
@@ -162,10 +161,7 @@ class MountainCarEnv(gym.Env):
self.state = np.array([self.np_random.uniform(low=low, high=high), 0]) self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
self.renderer.reset() self.renderer.reset()
self.renderer.render_step() self.renderer.render_step()
if not return_info: return np.array(self.state, dtype=np.float32), {}
return np.array(self.state, dtype=np.float32)
else:
return np.array(self.state, dtype=np.float32), {}
def _height(self, xs): def _height(self, xs):
return np.sin(3 * xs) * 0.45 + 0.55 return np.sin(3 * xs) * 0.45 + 0.55


@@ -138,13 +138,7 @@ class PendulumEnv(gym.Env):
self.renderer.render_step() self.renderer.render_step()
return self._get_obs(), -costs, False, False, {} return self._get_obs(), -costs, False, False, {}
def reset( def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed) super().reset(seed=seed)
if options is None: if options is None:
high = np.array([DEFAULT_X, DEFAULT_Y]) high = np.array([DEFAULT_X, DEFAULT_Y])
@@ -162,10 +156,7 @@ class PendulumEnv(gym.Env):
self.renderer.reset() self.renderer.reset()
self.renderer.render_step() self.renderer.render_step()
if not return_info: return self._get_obs(), {}
return self._get_obs()
else:
return self._get_obs(), {}
def _get_obs(self): def _get_obs(self):
theta, thetadot = self.state theta, thetadot = self.state


@@ -142,7 +142,6 @@ class BaseMujocoEnv(gym.Env):
self, self,
*, *,
seed: Optional[int] = None, seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
super().reset(seed=seed) super().reset(seed=seed)
@@ -152,10 +151,7 @@ class BaseMujocoEnv(gym.Env):
ob = self.reset_model() ob = self.reset_model()
self.renderer.reset() self.renderer.reset()
self.renderer.render_step() self.renderer.render_step()
if not return_info: return ob, {}
return ob
else:
return ob, {}
def set_state(self, qpos, qvel): def set_state(self, qpos, qvel):
""" """


@@ -167,7 +167,6 @@ class BlackjackEnv(gym.Env):
def reset( def reset(
self, self,
seed: Optional[int] = None, seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
super().reset(seed=seed) super().reset(seed=seed)
@@ -189,10 +188,7 @@ class BlackjackEnv(gym.Env):
self.renderer.reset() self.renderer.reset()
self.renderer.render_step() self.renderer.render_step()
if not return_info: return self._get_obs(), {}
return self._get_obs()
else:
return self._get_obs(), {}
def render(self): def render(self):
return self.renderer.get_renders() return self.renderer.get_renders()


@@ -149,22 +149,14 @@ class CliffWalkingEnv(Env):
self.renderer.render_step() self.renderer.render_step()
return (int(s), r, t, False, {"prob": p}) return (int(s), r, t, False, {"prob": p})
def reset( def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed) super().reset(seed=seed)
self.s = categorical_sample(self.initial_state_distrib, self.np_random) self.s = categorical_sample(self.initial_state_distrib, self.np_random)
self.lastaction = None self.lastaction = None
self.renderer.reset() self.renderer.reset()
self.renderer.render_step() self.renderer.render_step()
if not return_info:
return int(self.s) return int(self.s), {"prob": 1}
else:
return int(self.s), {"prob": 1}
def render(self): def render(self):
return self.renderer.get_renders() return self.renderer.get_renders()


@@ -256,7 +256,6 @@ class FrozenLakeEnv(Env):
self, self,
*, *,
seed: Optional[int] = None, seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
super().reset(seed=seed) super().reset(seed=seed)
@@ -266,10 +265,7 @@ class FrozenLakeEnv(Env):
self.renderer.reset() self.renderer.reset()
self.renderer.render_step() self.renderer.render_step()
if not return_info: return int(self.s), {"prob": 1}
return int(self.s)
else:
return int(self.s), {"prob": 1}
def render(self): def render(self):
return self.renderer.get_renders() return self.renderer.get_renders()


@@ -89,7 +89,7 @@ class TaxiEnv(Env):
### Info ### Info
``step`` and ``reset(return_info=True)`` will return an info dictionary that contains "p" and "action_mask" containing ``step`` and ``reset()`` will return an info dictionary that contains "p" and "action_mask" containing
the probability that the state is taken and a mask of what actions will result in a change of state to speed up training. the probability that the state is taken and a mask of what actions will result in a change of state to speed up training.
As Taxi's initial state is a stochastic, the "p" key represents the probability of the As Taxi's initial state is a stochastic, the "p" key represents the probability of the
@@ -266,7 +266,6 @@ class TaxiEnv(Env):
self, self,
*, *,
seed: Optional[int] = None, seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
super().reset(seed=seed) super().reset(seed=seed)
@@ -275,10 +274,8 @@ class TaxiEnv(Env):
self.taxi_orientation = 0 self.taxi_orientation = 0
self.renderer.reset() self.renderer.reset()
self.renderer.render_step() self.renderer.render_step()
if not return_info:
return int(self.s) return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)}
else:
return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)}
def render(self): def render(self):
return self.renderer.get_renders() return self.renderer.get_renders()
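
Because `reset()` now always returns the info dictionary, the action mask is available from the very first step of an episode. The sketch below shows one way a caller might use it. It is an assumption-laden illustration: `sample_masked_action` is a hypothetical helper, and the observation and mask values are made up to mimic the shape of the dictionary returned by the Taxi `reset` above.

```python
import random
from typing import Sequence


def sample_masked_action(action_mask: Sequence[int], rng: random.Random) -> int:
    """Pick uniformly among the actions whose mask entry is nonzero."""
    valid_actions = [a for a, allowed in enumerate(action_mask) if allowed]
    if not valid_actions:
        raise ValueError("action_mask allows no actions")
    return rng.choice(valid_actions)


# Stand-in for `obs, info = env.reset(seed=0)` on a Taxi-like environment
# (hypothetical values, not produced by Gym).
obs, info = 328, {"prob": 1.0, "action_mask": [1, 1, 0, 1, 0, 0]}
action = sample_masked_action(info["action_mask"], random.Random(0))
```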


@@ -73,7 +73,7 @@ def check_reset_seed(env: gym.Env):
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
): ):
try: try:
obs_1 = env.reset(seed=123) obs_1, info = env.reset(seed=123)
assert ( assert (
obs_1 in env.observation_space obs_1 in env.observation_space
), "The observation returned by `env.reset(seed=123)` is not within the observation space." ), "The observation returned by `env.reset(seed=123)` is not within the observation space."
@@ -85,7 +85,7 @@ def check_reset_seed(env: gym.Env):
env.unwrapped._np_random # pyright: ignore [reportPrivateUsage] env.unwrapped._np_random # pyright: ignore [reportPrivateUsage]
) )
obs_2 = env.reset(seed=123) obs_2, info = env.reset(seed=123)
assert ( assert (
obs_2 in env.observation_space obs_2 in env.observation_space
), "The observation returned by `env.reset(seed=123)` is not within the observation space." ), "The observation returned by `env.reset(seed=123)` is not within the observation space."
@@ -98,7 +98,7 @@ def check_reset_seed(env: gym.Env):
== seed_123_rng.bit_generator.state == seed_123_rng.bit_generator.state
), "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`." ), "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`."
obs_3 = env.reset(seed=456) obs_3, info = env.reset(seed=456)
assert ( assert (
obs_3 in env.observation_space obs_3 in env.observation_space
), "The observation returned by `env.reset(seed=456)` is not within the observation space." ), "The observation returned by `env.reset(seed=456)` is not within the observation space."
@@ -126,53 +126,6 @@ def check_reset_seed(env: gym.Env):
) )
def check_reset_info(env: gym.Env):
"""Checks that :meth:`reset` supports the ``return_info`` keyword.
Args:
env: The environment to check
Raises:
AssertionError: The environment cannot be reset with `return_info=True`,
even though `return_info` or `kwargs` appear in the signature.
"""
signature = inspect.signature(env.reset)
if "return_info" in signature.parameters or (
"kwargs" in signature.parameters
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
):
try:
obs = env.reset(return_info=False)
assert (
obs in env.observation_space
), "The value returned by `env.reset(return_info=True)` is not within the observation space."
result = env.reset(return_info=True)
assert isinstance(
result, tuple
), f"Calling the reset method with `return_info=True` did not return a tuple, actual type: {type(result)}"
assert (
len(result) == 2
), f"Calling the reset method with `return_info=True` did not return a 2-tuple, actual length: {len(result)}"
obs, info = result
assert (
obs in env.observation_space
), "The first element returned by `env.reset(return_info=True)` is not within the observation space."
assert isinstance(
info, dict
), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: {type(info)}"
except TypeError as e:
raise AssertionError(
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` appear in the signature. "
f"This should never happen, please report this issue. The error was: {e}"
)
else:
raise gym.error.Error(
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument."
)
def check_reset_options(env: gym.Env): def check_reset_options(env: gym.Env):
"""Check that the environment can be reset with options. """Check that the environment can be reset with options.
@@ -201,6 +154,64 @@ def check_reset_options(env: gym.Env):
) )
def check_reset_return_info_deprecation(env: gym.Env):
"""Makes sure support for deprecated `return_info` argument is dropped.
Args:
env: The environment to check
Raises:
UserWarning
"""
signature = inspect.signature(env.reset)
if "return_info" in signature.parameters:
logger.warn(
"`return_info` is deprecated as an optional argument to `reset`. `reset` "
"should now always return `obs, info` where `obs` is an observation, and `info` is a dictionary "
"containing additional information."
)
def check_seed_deprecation(env: gym.Env):
"""Makes sure support for deprecated function `seed` is dropped.
Args:
env: The environment to check
Raises:
UserWarning
"""
seed_fn = getattr(env, "seed", None)
if callable(seed_fn):
logger.warn(
"Official support for the `seed` function is dropped. "
"Standard practice is to reset gym environments using `env.reset(seed=<desired seed>)`"
)
def check_reset_return_type(env: gym.Env):
"""Checks that :meth:`reset` correctly returns a tuple of the form `(obs , info)`.
Args:
env: The environment to check
Raises:
AssertionError depending on spec violation
"""
result = env.reset()
assert isinstance(
result, tuple
), f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(result)}`"
assert (
len(result) == 2
), f"Calling the reset method did not return a 2-tuple, actual length: {len(result)}"
obs, info = result
assert (
obs in env.observation_space
), "The first element returned by `env.reset()` is not within the observation space."
assert isinstance(
info, dict
), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}"
def check_space_limit(space, space_type: str): def check_space_limit(space, space_type: str):
"""Check the space limit for only the Box space as a test that only runs as part of `check_env`.""" """Check the space limit for only the Box space as a test that only runs as part of `check_env`."""
if isinstance(space, spaces.Box): if isinstance(space, spaces.Box):
@@ -279,9 +290,11 @@ def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = False):
check_space_limit(env.observation_space, "observation") check_space_limit(env.observation_space, "observation")
# ==== Check the reset method ==== # ==== Check the reset method ====
check_seed_deprecation(env)
check_reset_return_info_deprecation(env)
check_reset_return_type(env)
check_reset_seed(env) check_reset_seed(env)
check_reset_options(env) check_reset_options(env)
check_reset_info(env)
# ============ Check the returned values =============== # ============ Check the returned values ===============
env_reset_passive_checker(env) env_reset_passive_checker(env)
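
The three reset checks that this hunk adds to `check_env` can be approximated without Gym as follows. This is a sketch, not the library's implementation: `LegacyEnv`, `ModernEnv`, and `reset_contract_problems` are hypothetical names, problems are collected in a list instead of being logged or asserted, and the observation-space membership test is omitted.

```python
import inspect
from typing import List


class LegacyEnv:
    """Hypothetical pre-commit style env: seed() method, optional info."""

    def seed(self, seed=None):
        return [seed]

    def reset(self, *, seed=None, return_info=False, options=None):
        return (0, {}) if return_info else 0


class ModernEnv:
    """Hypothetical post-commit style env: reset always returns (obs, info)."""

    def reset(self, *, seed=None, options=None):
        return 0, {}


def reset_contract_problems(env) -> List[str]:
    """Collect violations of the new reset contract."""
    problems = []
    # Mirrors check_seed_deprecation: a callable `seed` attribute is flagged.
    if callable(getattr(env, "seed", None)):
        problems.append("env.seed() is no longer supported; use reset(seed=...)")
    # Mirrors check_reset_return_info_deprecation: inspect the signature only.
    if "return_info" in inspect.signature(env.reset).parameters:
        problems.append("reset() still accepts the removed return_info argument")
    # Mirrors check_reset_return_type: the result must be a 2-tuple (obs, dict).
    result = env.reset()
    if not (isinstance(result, tuple) and len(result) == 2):
        problems.append("reset() did not return a 2-tuple (obs, info)")
    elif not isinstance(result[1], dict):
        problems.append("second element returned by reset() is not a dict")
    return problems
```

Calling `reset_contract_problems(ModernEnv())` yields an empty list, while `LegacyEnv` trips all three checks.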


@@ -183,14 +183,6 @@ def env_reset_passive_checker(env, **kwargs):
f"Actual default: {seed_param}" f"Actual default: {seed_param}"
) )
if "return_info" not in signature.parameters and not (
"kwargs" in signature.parameters
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
):
logger.warn(
"Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting."
)
if "options" not in signature.parameters and "kwargs" not in signature.parameters: if "options" not in signature.parameters and "kwargs" not in signature.parameters:
logger.warn( logger.warn(
"Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information." "Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information."
@@ -198,21 +190,17 @@ def env_reset_passive_checker(env, **kwargs):
# Checks the result of env.reset with kwargs # Checks the result of env.reset with kwargs
result = env.reset(**kwargs) result = env.reset(**kwargs)
if kwargs.get("return_info", False) is True:
assert isinstance(
result, tuple
), f"The result returned by `env.reset(return_info=True)` was not a tuple, actual type: {type(result)}"
assert (
len(result) == 2
), f"The length of the result returned by `env.reset(return_info=True)` is not 2, actual length: {len(result)}"
obs, info = result
assert isinstance(
info, dict
), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: {type(info)}"
else:
obs = result
if not isinstance(result, tuple):
logger.warn(
f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(result)}`"
)
obs, info = result
check_obs(obs, env.observation_space, "reset") check_obs(obs, env.observation_space, "reset")
assert isinstance(
info, dict
), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}"
return result return result


@@ -171,38 +171,9 @@ class AsyncVectorEnv(VectorEnv):
self._state = AsyncState.DEFAULT self._state = AsyncState.DEFAULT
self._check_spaces() self._check_spaces()
def seed(self, seed=None):
"""Seeds the vector environments.
Args:
seed: The seeds use with the environments
Raises:
AlreadyPendingCallError: Calling `seed` while waiting for a pending call to complete
"""
super().seed(seed=seed)
self._assert_is_running()
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
f"Calling `seed` while waiting for a pending call to `{self._state.value}` to complete.",
self._state.value,
)
for pipe, seed in zip(self.parent_pipes, seed):
pipe.send(("seed", seed))
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
def reset_async( def reset_async(
self, self,
seed: Optional[Union[int, List[int]]] = None, seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
"""Send calls to the :obj:`reset` methods of the sub-environments. """Send calls to the :obj:`reset` methods of the sub-environments.
@@ -211,7 +182,6 @@ class AsyncVectorEnv(VectorEnv):
Args: Args:
seed: List of seeds for each environment seed: List of seeds for each environment
return_info: If to return information
options: The reset option options: The reset option
Raises: Raises:
@@ -238,8 +208,6 @@ class AsyncVectorEnv(VectorEnv):
single_kwargs = {} single_kwargs = {}
if single_seed is not None: if single_seed is not None:
single_kwargs["seed"] = single_seed single_kwargs["seed"] = single_seed
if return_info:
single_kwargs["return_info"] = return_info
if options is not None: if options is not None:
single_kwargs["options"] = options single_kwargs["options"] = options
@@ -250,7 +218,6 @@ class AsyncVectorEnv(VectorEnv):
self, self,
timeout: Optional[Union[int, float]] = None, timeout: Optional[Union[int, float]] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
) -> Union[ObsType, Tuple[ObsType, List[dict]]]: ) -> Union[ObsType, Tuple[ObsType, List[dict]]]:
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results. """Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
@@ -258,7 +225,6 @@ class AsyncVectorEnv(VectorEnv):
Args: Args:
timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out. timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out.
seed: ignored seed: ignored
return_info: If to return information
options: ignored options: ignored
Returns: Returns:
@@ -286,27 +252,17 @@ class AsyncVectorEnv(VectorEnv):
self._raise_if_errors(successes) self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT self._state = AsyncState.DEFAULT
if return_info: infos = {}
infos = {} results, info_data = zip(*results)
results, info_data = zip(*results) for i, info in enumerate(info_data):
for i, info in enumerate(info_data): infos = self._add_info(infos, info, i)
infos = self._add_info(infos, info, i)
if not self.shared_memory: if not self.shared_memory:
self.observations = concatenate( self.observations = concatenate(
self.single_observation_space, results, self.observations self.single_observation_space, results, self.observations
) )
return ( return (deepcopy(self.observations) if self.copy else self.observations), infos
deepcopy(self.observations) if self.copy else self.observations
), infos
else:
if not self.shared_memory:
self.observations = concatenate(
self.single_observation_space, results, self.observations
)
return deepcopy(self.observations) if self.copy else self.observations
def step_async(self, actions: np.ndarray): def step_async(self, actions: np.ndarray):
"""Send the calls to :obj:`step` to each sub-environment. """Send the calls to :obj:`step` to each sub-environment.
@@ -606,12 +562,8 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
while True: while True:
command, data = pipe.recv() command, data = pipe.recv()
if command == "reset": if command == "reset":
if "return_info" in data and data["return_info"] is True: observation, info = env.reset(**data)
observation, info = env.reset(**data) pipe.send(((observation, info), True))
pipe.send(((observation, info), True))
else:
observation = env.reset(**data)
pipe.send((observation, True))
elif command == "step": elif command == "step":
( (
@@ -622,8 +574,9 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
info, info,
) = step_api_compatibility(env.step(data), True) ) = step_api_compatibility(env.step(data), True)
if terminated or truncated: if terminated or truncated:
info["final_observation"] = observation old_observation = observation
observation = env.reset() observation, info = env.reset()
info["final_observation"] = old_observation
pipe.send(((observation, reward, terminated, truncated, info), True)) pipe.send(((observation, reward, terminated, truncated, info), True))
elif command == "seed": elif command == "seed":
env.seed(data) env.seed(data)
@@ -676,18 +629,12 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
while True: while True:
command, data = pipe.recv() command, data = pipe.recv()
if command == "reset": if command == "reset":
if "return_info" in data and data["return_info"] is True: observation, info = env.reset(**data)
observation, info = env.reset(**data) write_to_shared_memory(
write_to_shared_memory( observation_space, index, observation, shared_memory
observation_space, index, observation, shared_memory )
) pipe.send(((None, info), True))
pipe.send(((None, info), True))
else:
observation = env.reset(**data)
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
pipe.send((None, True))
elif command == "step": elif command == "step":
( (
observation, observation,
@@ -697,8 +644,10 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
info, info,
) = step_api_compatibility(env.step(data), True) ) = step_api_compatibility(env.step(data), True)
if terminated or truncated: if terminated or truncated:
info["final_observation"] = observation old_observation = observation
observation = env.reset() observation, info = env.reset()
info["final_observation"] = old_observation
write_to_shared_memory( write_to_shared_memory(
observation_space, index, observation, shared_memory observation_space, index, observation, shared_memory
) )
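
The auto-reset branch in both workers follows one pattern: keep the terminal observation, call `reset()`, then store the saved observation in the info dict that `reset()` returned. A minimal sketch of that pattern, assuming a hypothetical `CountdownEnv` stub and a hypothetical helper `step_with_autoreset` (neither is part of this commit):

```python
from typing import Any, Dict, Tuple


class CountdownEnv:
    """Hypothetical stand-in for a gym environment (illustration only)."""

    def __init__(self, horizon: int = 3):
        self.horizon = horizon
        self.t = 0

    def reset(self, *, seed=None, options=None) -> Tuple[int, Dict[str, Any]]:
        # reset always returns an (observation, info) pair
        self.t = 0
        return self.t, {"reset": True}

    def step(self, action):
        self.t += 1
        terminated = self.t >= self.horizon
        return self.t, 1.0, terminated, False, {"t": self.t}


def step_with_autoreset(env, action):
    """Step once; if the episode ended, reset and keep the last observation."""
    observation, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        old_observation = observation
        # info is rebound to the dict returned by reset()
        observation, info = env.reset()
        info["final_observation"] = old_observation
    return observation, reward, terminated, truncated, info
```

In this sketch, as in the hunks above, the `info` returned by `step` is replaced by the one returned by `reset`, so only `final_observation` carries over from the finished episode.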


@@ -93,14 +93,12 @@ class SyncVectorEnv(VectorEnv):
def reset_wait( def reset_wait(
self, self,
seed: Optional[Union[int, List[int]]] = None, seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results. """Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
Args: Args:
seed: The reset environment seed seed: The reset environment seed
return_info: If to return information
options: Option information for the environment reset options: Option information for the environment reset
Returns: Returns:
@@ -123,26 +121,15 @@ class SyncVectorEnv(VectorEnv):
kwargs["seed"] = single_seed kwargs["seed"] = single_seed
if options is not None: if options is not None:
kwargs["options"] = options kwargs["options"] = options
if return_info is True:
kwargs["return_info"] = return_info
if not return_info: observation, info = env.reset(**kwargs)
observation = env.reset(**kwargs) observations.append(observation)
observations.append(observation) infos = self._add_info(infos, info, i)
else:
observation, info = env.reset(**kwargs)
observations.append(observation)
infos = self._add_info(infos, info, i)
self.observations = concatenate( self.observations = concatenate(
self.single_observation_space, observations, self.observations self.single_observation_space, observations, self.observations
) )
if not return_info: return (deepcopy(self.observations) if self.copy else self.observations), infos
return deepcopy(self.observations) if self.copy else self.observations
else:
return (
deepcopy(self.observations) if self.copy else self.observations
), infos
def step_async(self, actions): def step_async(self, actions):
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version.""" """Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
@@ -164,8 +151,9 @@ class SyncVectorEnv(VectorEnv):
info, info,
) = step_api_compatibility(env.step(action), True) ) = step_api_compatibility(env.step(action), True)
if self._terminateds[i] or self._truncateds[i]: if self._terminateds[i] or self._truncateds[i]:
info["final_observation"] = observation old_observation = observation
observation = env.reset() observation, info = env.reset()
info["final_observation"] = old_observation
observations.append(observation) observations.append(observation)
infos = self._add_info(infos, info, i) infos = self._add_info(infos, info, i)
self.observations = concatenate( self.observations = concatenate(


@@ -60,7 +60,6 @@ class VectorEnv(gym.Env):
def reset_async( def reset_async(
self, self,
seed: Optional[Union[int, List[int]]] = None, seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
"""Reset the sub-environments asynchronously. """Reset the sub-environments asynchronously.
@@ -70,7 +69,6 @@ class VectorEnv(gym.Env):
Args: Args:
seed: The reset seed seed: The reset seed
return_info: If to return info
options: Reset options options: Reset options
""" """
pass pass
@@ -78,7 +76,6 @@ class VectorEnv(gym.Env):
def reset_wait( def reset_wait(
self, self,
seed: Optional[Union[int, List[int]]] = None, seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
"""Retrieves the results of a :meth:`reset_async` call. """Retrieves the results of a :meth:`reset_async` call.
@@ -87,7 +84,6 @@ class VectorEnv(gym.Env):
Args: Args:
seed: The reset seed seed: The reset seed
return_info: If to return info
options: Reset options options: Reset options
Returns: Returns:
@@ -102,21 +98,19 @@ class VectorEnv(gym.Env):
self, self,
*, *,
seed: Optional[Union[int, List[int]]] = None, seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
"""Reset all parallel environments and return a batch of initial observations. """Reset all parallel environments and return a batch of initial observations.
Args: Args:
seed: The environment reset seeds seed: The environment reset seeds
return_info: If to return the info
options: Reset options options: Reset options
Returns: Returns:
A batch of observations from the vectorized environment. A batch of observations from the vectorized environment.
""" """
self.reset_async(seed=seed, return_info=return_info, options=options) self.reset_async(seed=seed, options=options)
return self.reset_wait(seed=seed, return_info=return_info, options=options) return self.reset_wait(seed=seed, options=options)
def step_async(self, actions): def step_async(self, actions):
"""Asynchronously performs steps in the sub-environments. """Asynchronously performs steps in the sub-environments.
@@ -220,21 +214,6 @@ class VectorEnv(gym.Env):
self.close_extras(**kwargs) self.close_extras(**kwargs)
self.closed = True self.closed = True
def seed(self, seed=None):
"""Set the random seed in all parallel environments.
Args:
seed: Random seed for each parallel environment. If ``seed`` is a list of
length ``num_envs``, then the items of the list are chosen as random
seeds. If ``seed`` is an int, then each parallel environment uses the random
seed ``seed + n``, where ``n`` is the index of the parallel environment
(between ``0`` and ``num_envs - 1``).
"""
deprecation(
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
"Please use `env.reset(seed=seed) instead in VectorEnvs."
)
def _add_info(self, infos: dict, info: dict, env_num: int) -> dict: def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
"""Add env info to the info dictionary of the vectorized environment. """Add env info to the info dictionary of the vectorized environment.
@@ -339,9 +318,6 @@ class VectorEnvWrapper(VectorEnv):
def close_extras(self, **kwargs): def close_extras(self, **kwargs):
return self.env.close_extras(**kwargs) return self.env.close_extras(**kwargs)
def seed(self, seed=None):
return self.env.seed(seed)
def call(self, name, *args, **kwargs): def call(self, name, *args, **kwargs):
return self.env.call(name, *args, **kwargs) return self.env.call(name, *args, **kwargs)
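
With `seed()` removed from the vector environments, seeding goes through `reset(seed=...)`, which accepts either an int or a list. The removed docstring describes the int case as `seed + n` for sub-environment `n`. A small sketch of that expansion, using a hypothetical helper `expand_seeds`; how `None` is handled and how the list length is validated are assumptions here:

```python
from typing import List, Optional, Union


def expand_seeds(
    seed: Optional[Union[int, List[int]]], num_envs: int
) -> List[Optional[int]]:
    """Turn a single seed (or None) into one seed per sub-environment."""
    if seed is None:
        # Assumption: no seed means every sub-environment is left unseeded.
        return [None] * num_envs
    if isinstance(seed, int):
        # Sub-environment n receives seed + n.
        return [seed + i for i in range(num_envs)]
    assert len(seed) == num_envs, "expected one seed per sub-environment"
    return list(seed)
```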


@@ -151,11 +151,7 @@ class AtariPreprocessing(gym.Wrapper):
def reset(self, **kwargs): def reset(self, **kwargs):
"""Resets the environment using preprocessing.""" """Resets the environment using preprocessing."""
# NoopReset # NoopReset
if kwargs.get("return_info", False): _, reset_info = self.env.reset(**kwargs)
_, reset_info = self.env.reset(**kwargs)
else:
_ = self.env.reset(**kwargs)
reset_info = {}
noops = ( noops = (
self.env.unwrapped.np_random.integers(1, self.noop_max + 1) self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
@@ -168,11 +164,7 @@ class AtariPreprocessing(gym.Wrapper):
) )
reset_info.update(step_info) reset_info.update(step_info)
if terminated or truncated: if terminated or truncated:
if kwargs.get("return_info", False): _, reset_info = self.env.reset(**kwargs)
_, reset_info = self.env.reset(**kwargs)
else:
_ = self.env.reset(**kwargs)
reset_info = {}
self.lives = self.ale.lives() self.lives = self.ale.lives()
if self.grayscale_obs: if self.grayscale_obs:
@@ -181,10 +173,7 @@ class AtariPreprocessing(gym.Wrapper):
self.ale.getScreenRGB(self.obs_buffer[0]) self.ale.getScreenRGB(self.obs_buffer[0])
self.obs_buffer[1].fill(0) self.obs_buffer[1].fill(0)
if kwargs.get("return_info", False): return self._get_obs(), reset_info
return self._get_obs(), reset_info
else:
return self._get_obs()
def _get_obs(self): def _get_obs(self):
if self.frame_skip > 1: # more efficient in-place pooling if self.frame_skip > 1: # more efficient in-place pooling


@@ -48,7 +48,7 @@ class AutoResetWrapper(gym.Wrapper):
if terminated or truncated: if terminated or truncated:
new_obs, new_info = self.env.reset(return_info=True) new_obs, new_info = self.env.reset()
assert ( assert (
"final_observation" not in new_info "final_observation" not in new_info
), 'info dict cannot contain key "final_observation" ' ), 'info dict cannot contain key "final_observation" '


@@ -191,14 +191,8 @@ class FrameStack(gym.ObservationWrapper):
Returns: Returns:
The stacked observations The stacked observations
""" """
if kwargs.get("return_info", False): obs, info = self.env.reset(**kwargs)
obs, info = self.env.reset(**kwargs)
else:
obs = self.env.reset(**kwargs)
info = None # Unused
[self.frames.append(obs) for _ in range(self.num_stack)] [self.frames.append(obs) for _ in range(self.num_stack)]
if kwargs.get("return_info", False): return self.observation(None), info
return self.observation(None), info
else:
return self.observation(None)
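
With `return_info` gone, a wrapper's `reset` no longer needs to branch on its keyword arguments. A rough sketch of the same idea outside of gym, using a hypothetical `ConstantEnv` stub and a hypothetical `TinyFrameStack` class (illustrative names, not gym APIs):

```python
from collections import deque


class ConstantEnv:
    """Hypothetical environment that always observes the same value."""

    def __init__(self, value):
        self.value = value

    def reset(self, *, seed=None, options=None):
        return self.value, {"seed": seed}


class TinyFrameStack:
    """Hypothetical wrapper that repeats the first observation on reset."""

    def __init__(self, env, num_stack: int):
        self.env = env
        self.num_stack = num_stack
        self.frames = deque(maxlen=num_stack)

    def reset(self, **kwargs):
        # Single code path: the wrapped reset always yields (obs, info).
        obs, info = self.env.reset(**kwargs)
        for _ in range(self.num_stack):
            self.frames.append(obs)
        return list(self.frames), info
```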


@@ -89,20 +89,12 @@ class NormalizeObservation(gym.core.Wrapper):
def reset(self, **kwargs): def reset(self, **kwargs):
"""Resets the environment and normalizes the observation.""" """Resets the environment and normalizes the observation."""
if kwargs.get("return_info", False): obs, info = self.env.reset(**kwargs)
obs, info = self.env.reset(**kwargs)
if self.is_vector_env: if self.is_vector_env:
return self.normalize(obs), info return self.normalize(obs), info
else:
return self.normalize(np.array([obs]))[0], info
else: else:
obs = self.env.reset(**kwargs) return self.normalize(np.array([obs]))[0], info
if self.is_vector_env:
return self.normalize(obs)
else:
return self.normalize(np.array([obs]))[0]
def normalize(self, obs): def normalize(self, obs):
"""Normalises the observation using the running mean and variance of the observations.""" """Normalises the observation using the running mean and variance of the observations."""


@@ -57,9 +57,6 @@ class VectorListInfo(gym.Wrapper):
def reset(self, **kwargs): def reset(self, **kwargs):
"""Resets the environment using kwargs.""" """Resets the environment using kwargs."""
if not kwargs.get("return_info"):
return self.env.reset(**kwargs)
obs, infos = self.env.reset(**kwargs) obs, infos = self.env.reset(**kwargs)
list_info = self._convert_info_to_list(infos) list_info = self._convert_info_to_list(infos)
return obs, list_info return obs, list_info


@@ -139,7 +139,7 @@ def test_taxi_action_mask():
def test_taxi_encode_decode(): def test_taxi_encode_decode():
env = TaxiEnv() env = TaxiEnv()
state = env.reset() state, info = env.reset()
for _ in range(100): for _ in range(100):
assert ( assert (
env.encode(*env.decode(state)) == state env.encode(*env.decode(state)) == state


@@ -82,8 +82,8 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
env_1 = env_spec.make(disable_env_checker=True) env_1 = env_spec.make(disable_env_checker=True)
env_2 = env_spec.make(disable_env_checker=True) env_2 = env_spec.make(disable_env_checker=True)
initial_obs_1 = env_1.reset(seed=SEED) initial_obs_1, initial_info_1 = env_1.reset(seed=SEED)
initial_obs_2 = env_2.reset(seed=SEED) initial_obs_2, initial_info_2 = env_2.reset(seed=SEED)
assert_equals(initial_obs_1, initial_obs_2) assert_equals(initial_obs_1, initial_obs_2)
env_1.action_space.seed(SEED) env_1.action_space.seed(SEED)


@@ -17,8 +17,8 @@ def verify_environments_match(
old_env = envs.make(old_env_id, disable_env_checker=True) old_env = envs.make(old_env_id, disable_env_checker=True)
new_env = envs.make(new_env_id, disable_env_checker=True) new_env = envs.make(new_env_id, disable_env_checker=True)
old_reset_obs = old_env.reset(seed=seed) old_reset_obs, old_info = old_env.reset(seed=seed)
new_reset_obs = new_env.reset(seed=seed) new_reset_obs, new_info = new_env.reset(seed=seed)
np.testing.assert_allclose(old_reset_obs, new_reset_obs) np.testing.assert_allclose(old_reset_obs, new_reset_obs)
@@ -56,7 +56,7 @@ EXCLUDE_POS_FROM_OBS = [
def test_obs_space_mujoco_environments(env_spec: EnvSpec): def test_obs_space_mujoco_environments(env_spec: EnvSpec):
"""Check that the returned observations are contained in the observation space of the environment""" """Check that the returned observations are contained in the observation space of the environment"""
env = env_spec.make(disable_env_checker=True) env = env_spec.make(disable_env_checker=True)
reset_obs = env.reset() reset_obs, info = env.reset()
assert env.observation_space.contains( assert env.observation_space.contains(
reset_obs reset_obs
), f"Observation returned by reset() of {env_spec.id} is not contained in the default observation space {env.observation_space}." ), f"Observation returned by reset() of {env_spec.id} is not contained in the default observation space {env.observation_space}."
@@ -73,7 +73,7 @@ def test_obs_space_mujoco_environments(env_spec: EnvSpec):
env = env_spec.make( env = env_spec.make(
disable_env_checker=True, exclude_current_positions_from_observation=False disable_env_checker=True, exclude_current_positions_from_observation=False
) )
reset_obs = env.reset() reset_obs, info = env.reset()
assert env.observation_space.contains( assert env.observation_space.contains(
reset_obs reset_obs
), f"Observation of {env_spec.id} is not contained in the default observation space {env.observation_space} when excluding current position from observation." ), f"Observation of {env_spec.id} is not contained in the default observation space {env.observation_space} when excluding current position from observation."
@@ -86,7 +86,7 @@ def test_obs_space_mujoco_environments(env_spec: EnvSpec):
# Ant-v4 has the option of including contact forces in the observation space with the use_contact_forces argument # Ant-v4 has the option of including contact forces in the observation space with the use_contact_forces argument
if env_spec.name == "Ant" and env_spec.version == 4: if env_spec.name == "Ant" and env_spec.version == 4:
env = env_spec.make(disable_env_checker=True, use_contact_forces=True) env = env_spec.make(disable_env_checker=True, use_contact_forces=True)
reset_obs = env.reset() reset_obs, info = env.reset()
assert env.observation_space.contains( assert env.observation_space.contains(
reset_obs reset_obs
), f"Observation of {env_spec.id} is not contained in the default observation space {env.observation_space} when using contact forces." ), f"Observation of {env_spec.id} is not contained in the default observation space {env.observation_space} when using contact forces."


@@ -21,17 +21,9 @@ class UnittestEnv(core.Env):
observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8) observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
action_space = spaces.Discrete(3) action_space = spaces.Discrete(3)
def reset( def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed) super().reset(seed=seed)
if return_info: return self.observation_space.sample(), {"info": "dummy"}
return self.observation_space.sample(), {"info": "dummy"}
return self.observation_space.sample() # Dummy observation
def step(self, action): def step(self, action):
observation = self.observation_space.sample() # Dummy observation observation = self.observation_space.sample() # Dummy observation
@@ -45,22 +37,13 @@ class UnknownSpacesEnv(core.Env):
on external resources), it is not encouraged. on external resources), it is not encouraged.
""" """
def reset( def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed) super().reset(seed=seed)
self.observation_space = spaces.Box( self.observation_space = spaces.Box(
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8 low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
) )
self.action_space = spaces.Discrete(3) self.action_space = spaces.Discrete(3)
if not return_info: return self.observation_space.sample(), {} # Dummy observation with info
return self.observation_space.sample() # Dummy observation
else:
return self.observation_space.sample(), {} # Dummy observation with info
def step(self, action): def step(self, action):
observation = self.observation_space.sample() # Dummy observation observation = self.observation_space.sample() # Dummy observation


@@ -12,16 +12,12 @@ def basic_reset_fn(
self, self,
*, *,
seed: Optional[int] = None, seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
) -> Union[ObsType, Tuple[ObsType, dict]]: ) -> Union[ObsType, Tuple[ObsType, dict]]:
"""A basic reset function that will pass the environment check using random samples from the observation space.""" """A basic reset function that will pass the environment check using random samples from the observation space."""
super(GenericTestEnv, self).reset(seed=seed) super(GenericTestEnv, self).reset(seed=seed)
self.observation_space.seed(seed) self.observation_space.seed(seed)
if return_info: return self.observation_space.sample(), {"options": options}
return self.observation_space.sample(), {"options": options}
else:
return self.observation_space.sample()
def new_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: def new_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
@@ -77,7 +73,6 @@ class GenericTestEnv(gym.Env):
self, self,
*, *,
seed: Optional[int] = None, seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
) -> Union[ObsType, Tuple[ObsType, dict]]: ) -> Union[ObsType, Tuple[ObsType, dict]]:
# If you need a default working reset function, use `basic_reset_fn` above # If you need a default working reset function, use `basic_reset_fn` above


@@ -1,16 +1,21 @@
"""Tests that the `env_checker` runs as expected and all errors are possible.""" """Tests that the `env_checker` runs as expected and all errors are possible."""
import re import re
import warnings
from typing import Tuple, Union
import numpy as np import numpy as np
import pytest import pytest
import gym import gym
from gym import spaces from gym import spaces
from gym.core import ObsType
from gym.utils.env_checker import ( from gym.utils.env_checker import (
check_env, check_env,
check_reset_info,
check_reset_options, check_reset_options,
check_reset_return_info_deprecation,
check_reset_return_type,
check_reset_seed, check_reset_seed,
check_seed_deprecation,
) )
from tests.testing_env import GenericTestEnv from tests.testing_env import GenericTestEnv
@@ -48,29 +53,27 @@ def test_no_error_warnings(env):
assert len(warnings) == 0, [warning.message for warning in warnings] assert len(warnings) == 0, [warning.message for warning in warnings]
def _no_super_reset(self, seed=None, return_info=False, options=None): def _no_super_reset(self, seed=None, options=None):
self.np_random.random() # generates a new prng self.np_random.random() # generates a new prng
# generate seed deterministic result # generate seed deterministic result
self.observation_space.seed(0) self.observation_space.seed(0)
return self.observation_space.sample() return self.observation_space.sample(), {}
def _super_reset_fixed(self, seed=None, return_info=False, options=None): def _super_reset_fixed(self, seed=None, options=None):
# Call super that ignores the seed passed, use fixed seed # Call super that ignores the seed passed, use fixed seed
super(GenericTestEnv, self).reset(seed=1) super(GenericTestEnv, self).reset(seed=1)
# deterministic output # deterministic output
self.observation_space._np_random = self.np_random self.observation_space._np_random = self.np_random
return self.observation_space.sample() return self.observation_space.sample(), {}
def _reset_default_seed( def _reset_default_seed(self: GenericTestEnv, seed="Error", options=None):
self: GenericTestEnv, seed="Error", return_info=False, options=None
):
super(GenericTestEnv, self).reset(seed=seed) super(GenericTestEnv, self).reset(seed=seed)
self.observation_space._np_random = ( # pyright: ignore [reportPrivateUsage] self.observation_space._np_random = ( # pyright: ignore [reportPrivateUsage]
self.np_random self.np_random
) )
return self.observation_space.sample() return self.observation_space.sample(), {}
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -78,12 +81,12 @@ def _reset_default_seed(
[ [
[ [
gym.error.Error, gym.error.Error,
lambda self: self.observation_space.sample(), lambda self: (self.observation_space.sample(), {}),
"The `reset` method does not provide a `seed` or `**kwargs` keyword argument.", "The `reset` method does not provide a `seed` or `**kwargs` keyword argument.",
], ],
[ [
AssertionError, AssertionError,
lambda self, seed, *_: self.observation_space.sample(), lambda self, seed, *_: (self.observation_space.sample(), {}),
"Expects the random number generator to have been generated given a seed was passed to reset. Mostly likely the environment reset function does not call `super().reset(seed=seed)`.", "Expects the random number generator to have been generated given a seed was passed to reset. Mostly likely the environment reset function does not call `super().reset(seed=seed)`.",
], ],
[ [
@@ -115,95 +118,125 @@ def test_check_reset_seed(test, func: callable, message: str):
check_reset_seed(GenericTestEnv(reset_fn=func)) check_reset_seed(GenericTestEnv(reset_fn=func))
def _deprecated_return_info(
self, return_info: bool = False
) -> Union[Tuple[ObsType, dict], ObsType]:
"""function to simulate the signature and behavior of a `reset` function with the deprecated `return_info` optional argument"""
if return_info:
return self.observation_space.sample(), {}
else:
return self.observation_space.sample()
def _reset_var_keyword_kwargs(self, kwargs): def _reset_var_keyword_kwargs(self, kwargs):
return self.observation_space.sample() return self.observation_space.sample(), {}
def _reset_return_info_type(self, seed=None, return_info=False, options=None): def _reset_return_info_type(self, seed=None, options=None):
if return_info: """Returns a `list` instead of a `tuple`. This function is used to make sure `env_checker` correctly
return [1, 2] checks that the return type of `env.reset()` is a `tuple`"""
else: return [self.observation_space.sample(), {}]
return self.observation_space.sample()
def _reset_return_info_length(self, seed=None, return_info=False, options=None): def _reset_return_info_length(self, seed=None, options=None):
if return_info: return 1, 2, 3
return 1, 2, 3
else:
return self.observation_space.sample()
def _return_info_obs_outside(self, seed=None, return_info=False, options=None): def _return_info_obs_outside(self, seed=None, options=None):
if return_info: return self.observation_space.sample() + self.observation_space.high, {}
return self.observation_space.sample() + self.observation_space.high, {}
else:
return self.observation_space.sample()
def _return_info_not_dict(self, seed=None, return_info=False, options=None): def _return_info_not_dict(self, seed=None, options=None):
if return_info: return self.observation_space.sample(), ["key", "value"]
return self.observation_space.sample(), ["key", "value"]
else:
return self.observation_space.sample()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"test,func,message", "test,func,message",
[ [
[
gym.error.Error,
lambda self, *_: self.observation_space.sample(),
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument.",
],
[
gym.error.Error,
_reset_var_keyword_kwargs,
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument.",
],
[ [
AssertionError, AssertionError,
_reset_return_info_type, _reset_return_info_type,
"Calling the reset method with `return_info=True` did not return a tuple, actual type: <class 'list'>", "The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `<class 'list'>`",
], ],
[ [
AssertionError, AssertionError,
_reset_return_info_length, _reset_return_info_length,
"Calling the reset method with `return_info=True` did not return a 2-tuple, actual length: 3", "Calling the reset method did not return a 2-tuple, actual length: 3",
], ],
[ [
AssertionError, AssertionError,
_return_info_obs_outside, _return_info_obs_outside,
"The first element returned by `env.reset(return_info=True)` is not within the observation space.", "The first element returned by `env.reset()` is not within the observation space.",
], ],
[ [
AssertionError, AssertionError,
_return_info_not_dict, _return_info_not_dict,
"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: <class 'list'>", "The second element returned by `env.reset()` was not a dictionary, actual type: <class 'list'>",
], ],
], ],
) )
def test_check_reset_info(test, func: callable, message: str): def test_check_reset_return_type(test, func: callable, message: str):
"""Tests the check reset info function works as expected.""" """Tests that `check_reset_return_type` verifies the return type of `env.reset()`."""
if test is UserWarning:
with pytest.warns( with pytest.raises(test, match=f"^{re.escape(message)}$"):
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$" check_reset_return_type(GenericTestEnv(reset_fn=func))
):
check_reset_info(GenericTestEnv(reset_fn=func))
else: @pytest.mark.parametrize(
with pytest.raises(test, match=f"^{re.escape(message)}$"): "test,func,message",
check_reset_info(GenericTestEnv(reset_fn=func)) [
[
UserWarning,
_deprecated_return_info,
"`return_info` is deprecated as an optional argument to `reset`. `reset`"
"should now always return `obs, info` where `obs` is an observation, and `info` is a dictionary"
"containing additional information.",
],
],
)
def test_check_reset_return_info_deprecation(test, func: callable, message: str):
"""Tests that return_info has been correctly deprecated as an argument to `env.reset()`."""
with pytest.warns(test, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"):
check_reset_return_info_deprecation(GenericTestEnv(reset_fn=func))
def test_check_seed_deprecation():
"""Tests that `check_seed_deprecation()` throws a warning if `env.seed()` has not been removed."""
message = """Official support for the `seed` function is dropped. Standard practice is to reset gym environments using `env.reset(seed=<desired seed>)`"""
env = GenericTestEnv()
def seed(seed):
return
with pytest.warns(
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
):
env.seed = seed
assert callable(env.seed)
check_seed_deprecation(env)
with warnings.catch_warnings(record=True) as caught_warnings:
env.seed = []
check_seed_deprecation(env)
env.seed = 123
check_seed_deprecation(env)
del env.seed
check_seed_deprecation(env)
assert len(caught_warnings) == 0
def test_check_reset_options(): def test_check_reset_options():
"""Tests the check_reset_options function.""" """Tests the check_reset_options function."""
with pytest.raises( with pytest.raises(
gym.error.Error, gym.error.Error,
match=re.escape( match=re.escape(
"The `reset` method does not provide an `options` or `**kwargs` keyword argument" "The `reset` method does not provide an `options` or `**kwargs` keyword argument"
), ),
): ):
check_reset_options(GenericTestEnv(reset_fn=lambda self: 0)) check_reset_options(GenericTestEnv(reset_fn=lambda self: (0, {})))
@pytest.mark.parametrize( @pytest.mark.parametrize(
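
The tests above exercise two kinds of checks: a warning when `reset` still declares `return_info`, and assertions on the shape of the value that `reset` returns. A minimal sketch of how such checks could be written with the standard library; `warn_if_return_info` and `validate_reset_result` are hypothetical names and the messages are placeholders, not the actual `gym.utils.env_checker` implementation:

```python
import inspect
import warnings


def warn_if_return_info(reset_fn) -> bool:
    """Warn if the reset signature still declares `return_info`."""
    params = inspect.signature(reset_fn).parameters
    if "return_info" in params:
        # Placeholder message, not the library's wording.
        warnings.warn("`return_info` should be removed from `reset`", UserWarning)
        return True
    return False


def validate_reset_result(result):
    """Assert that a reset result is a 2-tuple of (obs, info-dict)."""
    assert isinstance(result, tuple), f"not a tuple: {type(result)}"
    assert len(result) == 2, f"not a 2-tuple, actual length: {len(result)}"
    obs, info = result
    assert isinstance(info, dict), f"info is not a dict: {type(info)}"
    return obs, info
```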


@@ -242,24 +242,20 @@ def test_check_obs(test, obs, obs_space: spaces.Space, message: str):
assert len(warnings) == 0 assert len(warnings) == 0
def _reset_no_seed(self, return_info=False, options=None): def _reset_no_seed(self, options=None):
return self.observation_space.sample() return self.observation_space.sample(), {}
def _reset_seed_default(self, seed="error", return_info=False, options=None): def _reset_seed_default(self, seed="error", options=None):
return self.observation_space.sample() return self.observation_space.sample(), {}
def _reset_no_return_info(self, seed=None, options=None): def _reset_no_option(self, seed=None):
return self.observation_space.sample() return self.observation_space.sample(), {}
def _reset_no_option(self, seed=None, return_info=False):
return self.observation_space.sample()
def _make_reset_results(results): def _make_reset_results(results):
def _reset_result(self, seed=None, return_info=False, options=None): def _reset_result(self, seed=None, options=None):
return results return results
return _reset_result return _reset_result
@@ -280,12 +276,6 @@ def _make_reset_results(results):
"The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic. Actual default: seed='error'", "The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic. Actual default: seed='error'",
{}, {},
], ],
[
UserWarning,
_reset_no_return_info,
"Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting.",
{},
],
[ [
UserWarning, UserWarning,
_reset_no_option, _reset_no_option,
@@ -293,16 +283,16 @@ def _make_reset_results(results):
{}, {},
], ],
[ [
AssertionError, UserWarning,
_make_reset_results([0, {}]), _make_reset_results([0, {}]),
"The result returned by `env.reset(return_info=True)` was not a tuple, actual type: <class 'list'>", "The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `<class 'list'>`",
{"return_info": True}, {},
], ],
[ [
AssertionError, AssertionError,
_make_reset_results((0, {1, 2})), _make_reset_results((np.array([0], dtype=np.float32), {1, 2})),
"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: <class 'set'>", "The second element returned by `env.reset()` was not a dictionary, actual type: <class 'set'>",
{"return_info": True}, {},
], ],
], ],
) )
@@ -317,6 +307,8 @@ def test_passive_env_reset_checker(test, func: callable, message: str, kwargs: D
with pytest.warns(None) as warnings: with pytest.warns(None) as warnings:
with pytest.raises(test, match=f"^{re.escape(message)}$"): with pytest.raises(test, match=f"^{re.escape(message)}$"):
env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs) env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs)
for warning in warnings:
print(warning)
assert len(warnings) == 0 assert len(warnings) == 0
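
The reset contract exercised by the parametrized cases above (a non-tuple result triggers a warning, a non-dict `info` fails an assertion) can be sketched without gym. The helper below is a hypothetical illustration, not gym's `env_reset_passive_checker`; its name, `validate_reset_result`, and its messages are assumptions made for this example.

```python
import warnings
from typing import Any, Dict, Tuple


def validate_reset_result(result: Any) -> Tuple[Any, Dict]:
    """Hypothetical helper: check that `result` has the form `(obs, info)`."""
    # A list such as [obs, info] is tolerated, but flagged with a warning.
    if not isinstance(result, tuple):
        warnings.warn(
            f"reset() should return a tuple `(obs, info)`, actual type: {type(result)}",
            UserWarning,
        )
    # Anything that cannot be unpacked into exactly two elements is rejected.
    assert isinstance(result, (tuple, list)) and len(result) == 2, (
        "reset() should return exactly two elements"
    )
    obs, info = result
    # The second element must be a dictionary of additional information.
    assert isinstance(info, dict), (
        f"The second element returned by reset() was not a dictionary, actual type: {type(info)}"
    )
    return obs, info
```

Calling `validate_reset_result((0, {}))` returns the pair unchanged, while `validate_reset_result((0, {1, 2}))` raises an `AssertionError` because the second element is a set.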


@@ -28,7 +28,7 @@ def test_reset_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(8)] env_fns = [make_env("CartPole-v1", i) for i in range(8)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations = env.reset() observations, infos = env.reset()
env.close() env.close()
@@ -40,19 +40,7 @@ def test_reset_async_vector_env(shared_memory):
try: try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations = env.reset(return_info=False) observations, infos = env.reset()
finally:
env.close()
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations, infos = env.reset(return_info=True)
finally: finally:
env.close() env.close()
@@ -143,7 +131,7 @@ def test_copy_async_vector_env(shared_memory):
# TODO, these tests do nothing, understand the purpose of the tests and fix them # TODO, these tests do nothing, understand the purpose of the tests and fix them
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=True) env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=True)
observations = env.reset() observations, infos = env.reset()
observations[0] = 0 observations[0] = 0
env.close() env.close()
@@ -155,7 +143,7 @@ def test_no_copy_async_vector_env(shared_memory):
# TODO, these tests do nothing, understand the purpose of the tests and fix them # TODO, these tests do nothing, understand the purpose of the tests and fix them
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=False) env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=False)
observations = env.reset() observations, infos = env.reset()
observations[0] = 0 observations[0] = 0
env.close() env.close()
@@ -268,7 +256,7 @@ def test_custom_space_async_vector_env():
env_fns = [make_custom_space_env(i) for i in range(4)] env_fns = [make_custom_space_env(i) for i in range(4)]
env = AsyncVectorEnv(env_fns, shared_memory=False) env = AsyncVectorEnv(env_fns, shared_memory=False)
reset_observations = env.reset() reset_observations, reset_infos = env.reset()
assert isinstance(env.single_action_space, CustomSpace) assert isinstance(env.single_action_space, CustomSpace)
assert isinstance(env.action_space, Tuple) assert isinstance(env.action_space, Tuple)


@@ -12,7 +12,7 @@ class OldStepEnv(gym.Env):
self.observation_space = Discrete(2) self.observation_space = Discrete(2)
def reset(self): def reset(self):
return 0 return 0, {}
def step(self, action): def step(self, action):
obs = self.observation_space.sample() obs = self.observation_space.sample()
@@ -28,7 +28,7 @@ class NewStepEnv(gym.Env):
self.observation_space = Discrete(2) self.observation_space = Discrete(2)
def reset(self): def reset(self):
return 0 return 0, {}
def step(self, action): def step(self, action):
obs = self.observation_space.sample() obs = self.observation_space.sample()


@@ -24,7 +24,7 @@ def test_create_sync_vector_env():
def test_reset_sync_vector_env(): def test_reset_sync_vector_env():
env_fns = [make_env("CartPole-v1", i) for i in range(8)] env_fns = [make_env("CartPole-v1", i) for i in range(8)]
env = SyncVectorEnv(env_fns) env = SyncVectorEnv(env_fns)
observations = env.reset() observations, infos = env.reset()
env.close() env.close()
assert isinstance(env.observation_space, Box) assert isinstance(env.observation_space, Box)
@@ -35,32 +35,6 @@ def test_reset_sync_vector_env():
del observations del observations
env = SyncVectorEnv(env_fns)
observations = env.reset(return_info=False)
env.close()
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
del observations
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
env = SyncVectorEnv(env_fns)
observations, infos = env.reset(return_info=True)
env.close()
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
assert isinstance(infos, dict)
assert all([isinstance(info, dict) for info in infos])
@pytest.mark.parametrize("use_single_action_space", [True, False]) @pytest.mark.parametrize("use_single_action_space", [True, False])
def test_step_sync_vector_env(use_single_action_space): def test_step_sync_vector_env(use_single_action_space):
@@ -145,7 +119,7 @@ def test_custom_space_sync_vector_env():
env_fns = [make_custom_space_env(i) for i in range(4)] env_fns = [make_custom_space_env(i) for i in range(4)]
env = SyncVectorEnv(env_fns) env = SyncVectorEnv(env_fns)
reset_observations = env.reset() reset_observations, infos = env.reset()
assert isinstance(env.single_action_space, CustomSpace) assert isinstance(env.single_action_space, CustomSpace)
assert isinstance(env.action_space, Tuple) assert isinstance(env.action_space, Tuple)


@@ -22,8 +22,8 @@ def test_vector_env_equal(shared_memory):
assert async_env.action_space == sync_env.action_space assert async_env.action_space == sync_env.action_space
assert async_env.single_action_space == sync_env.single_action_space assert async_env.single_action_space == sync_env.single_action_space
async_observations = async_env.reset(seed=0) async_observations, async_infos = async_env.reset(seed=0)
sync_observations = sync_env.reset(seed=0) sync_observations, sync_infos = sync_env.reset(seed=0)
assert np.all(async_observations == sync_observations) assert np.all(async_observations == sync_observations)
for _ in range(num_steps): for _ in range(num_steps):


@@ -63,7 +63,7 @@ class UnittestSlowEnv(gym.Env):
super().reset(seed=seed) super().reset(seed=seed)
if self.slow_reset > 0: if self.slow_reset > 0:
time.sleep(self.slow_reset) time.sleep(self.slow_reset)
return self.observation_space.sample() return self.observation_space.sample(), {}
def step(self, action): def step(self, action):
time.sleep(action) time.sleep(action)
@@ -99,7 +99,7 @@ class CustomSpaceEnv(gym.Env):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed) super().reset(seed=seed)
return "reset" return "reset", {}
def step(self, action): def step(self, action):
observation = f"step({action:s})" observation = f"step({action:s})"


@@ -86,9 +86,7 @@ def test_atari_preprocessing_grayscale(env, obs_shape):
# It is not possible to test the outputs as we are not using actual observations. # It is not possible to test the outputs as we are not using actual observations.
# todo: update when ale-py is compatible with the ci # todo: update when ale-py is compatible with the ci
obs = env.reset(seed=0) obs, _ = env.reset(seed=0)
assert obs in env.observation_space
obs, _ = env.reset(seed=0, return_info=True)
assert obs in env.observation_space assert obs in env.observation_space
obs, _, _, _ = env.step(env.action_space.sample()) obs, _, _, _ = env.step(env.action_space.sample())
@@ -110,7 +108,7 @@ def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10):
noop_max=0, noop_max=0,
) )
obs = env.reset() obs, _ = env.reset()
max_obs = 1 if scaled else 255 max_obs = 1 if scaled else 255
assert np.all(0 <= obs) and np.all(obs <= max_obs) assert np.all(0 <= obs) and np.all(obs <= max_obs)


@@ -39,19 +39,10 @@ class DummyResetEnv(gym.Env):
{"count": self.count}, # Info {"count": self.count}, # Info
) )
def reset( def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
self,
*,
seed: Optional[int] = None,
return_info: Optional[bool] = False,
options: Optional[dict] = None
):
"""Resets the DummyEnv to return the count array and info with count.""" """Resets the DummyEnv to return the count array and info with count."""
self.count = 0 self.count = 0
if not return_info: return np.array([self.count]), {"count": self.count}
return np.array([self.count])
else:
return np.array([self.count]), {"count": self.count}
def unwrap_env(env) -> Generator[gym.Wrapper, None, None]: def unwrap_env(env) -> Generator[gym.Wrapper, None, None]:
@@ -113,7 +104,7 @@ def test_autoreset_wrapper_autoreset():
env = DummyResetEnv() env = DummyResetEnv()
env = AutoResetWrapper(env) env = AutoResetWrapper(env)
obs, info = env.reset(return_info=True) obs, info = env.reset()
assert obs == np.array([0]) assert obs == np.array([0])
assert info == {"count": 0} assert info == {"count": 0}


@@ -25,16 +25,10 @@ class FakeEnvironment(gym.Env):
image_shape = (32, 32, 3) image_shape = (32, 32, 3)
return np.zeros(image_shape, dtype=np.uint8) return np.zeros(image_shape, dtype=np.uint8)
def reset( def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed) super().reset(seed=seed)
observation = self.observation_space.sample() observation = self.observation_space.sample()
return observation if not return_info else (observation, {}) return observation, {}
def step(self, action): def step(self, action):
del action del action
@@ -79,8 +73,9 @@ class TestFilterObservation:
assert tuple(wrapped_env.observation_space.spaces.keys()) == tuple(filter_keys) assert tuple(wrapped_env.observation_space.spaces.keys()) == tuple(filter_keys)
# Check that the added space item is consistent with the added observation. # Check that the added space item is consistent with the added observation.
observation = wrapped_env.reset() observation, info = wrapped_env.reset()
assert len(observation) == len(filter_keys) assert len(observation) == len(filter_keys)
assert isinstance(info, dict)
@pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES) @pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES)
def test_raises_with_incorrect_arguments( def test_raises_with_incorrect_arguments(


@@ -18,7 +18,7 @@ class FakeEnvironment(gym.Env):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed) super().reset(seed=seed)
self.observation = self.observation_space.sample() self.observation = self.observation_space.sample()
return self.observation return self.observation, {}
OBSERVATION_SPACES = ( OBSERVATION_SPACES = (
@@ -67,7 +67,7 @@ class TestFlattenEnvironment:
""" """
env = FakeEnvironment(observation_space=observation_space) env = FakeEnvironment(observation_space=observation_space)
wrapped_env = FlattenObservation(env) wrapped_env = FlattenObservation(env)
flattened = wrapped_env.reset() flattened, info = wrapped_env.reset()
unflattened = unflatten(env.observation_space, flattened) unflattened = unflatten(env.observation_space, flattened)
original = env.observation original = env.observation


@@ -11,11 +11,13 @@ def test_flatten_observation(env_id):
env = gym.make(env_id, disable_env_checker=True) env = gym.make(env_id, disable_env_checker=True)
wrapped_env = FlattenObservation(env) wrapped_env = FlattenObservation(env)
obs = env.reset() obs, info = env.reset()
wrapped_obs = wrapped_env.reset() wrapped_obs, wrapped_obs_info = wrapped_env.reset()
space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))) space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
wrapped_space = spaces.Box(0, 1, [32 + 11 + 2], dtype=np.int64) wrapped_space = spaces.Box(0, 1, [32 + 11 + 2], dtype=np.int64)
assert space.contains(obs) assert space.contains(obs)
assert wrapped_space.contains(wrapped_obs) assert wrapped_space.contains(wrapped_obs)
assert isinstance(info, dict)
assert isinstance(wrapped_obs_info, dict)


@@ -33,8 +33,8 @@ def test_frame_stack(env_id, num_stack, lz4_compress):
dup = gym.make(env_id, disable_env_checker=True) dup = gym.make(env_id, disable_env_checker=True)
obs = env.reset(seed=0) obs, _ = env.reset(seed=0)
dup_obs = dup.reset(seed=0) dup_obs, _ = dup.reset(seed=0)
assert np.allclose(obs[-1], dup_obs) assert np.allclose(obs[-1], dup_obs)
for _ in range(num_stack**2): for _ in range(num_stack**2):


@@ -22,5 +22,5 @@ def test_gray_scale_observation(env_id, keep_dim):
else: else:
assert len(wrapped_env.observation_space.shape) == 2 assert len(wrapped_env.observation_space.shape) == 2
wrapped_obs = wrapped_env.reset() wrapped_obs, info = wrapped_env.reset()
assert wrapped_obs in wrapped_env.observation_space assert wrapped_obs in wrapped_env.observation_space


@@ -23,7 +23,7 @@ class FakeEnvironment(gym.Env):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed) super().reset(seed=seed)
observation = self.observation_space.sample() observation = self.observation_space.sample()
return observation return observation, {}
def step(self, action): def step(self, action):
del action del action
@@ -115,5 +115,6 @@ class TestNestedDictWrapper:
def test_nested_dicts_ravel(self, observation_space, flat_shape): def test_nested_dicts_ravel(self, observation_space, flat_shape):
env = FakeEnvironment(observation_space=observation_space) env = FakeEnvironment(observation_space=observation_space)
wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys)) wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
obs = wrapped_env.reset() obs, info = wrapped_env.reset()
assert obs.shape == wrapped_env.observation_space.shape assert obs.shape == wrapped_env.observation_space.shape
assert isinstance(info, dict)


@@ -1,7 +1,6 @@
from typing import Optional from typing import Optional
import numpy as np import numpy as np
import pytest
from numpy.testing import assert_almost_equal from numpy.testing import assert_almost_equal
import gym import gym
@@ -24,19 +23,10 @@ class DummyRewardEnv(gym.Env):
self.t += 1 self.t += 1
return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {} return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}
def reset( def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
self,
*,
seed: Optional[int] = None,
return_info: Optional[bool] = False,
options: Optional[dict] = None
):
super().reset(seed=seed) super().reset(seed=seed)
self.t = self.return_reward_idx self.t = self.return_reward_idx
if not return_info: return np.array([self.t]), {}
return np.array([self.t])
else:
return np.array([self.t]), {}
def make_env(return_reward_idx): def make_env(return_reward_idx):
@@ -47,11 +37,10 @@ def make_env(return_reward_idx):
return thunk return thunk
@pytest.mark.parametrize("return_info", [False, True]) def test_normalize_observation():
def test_normalize_observation(return_info: bool):
env = DummyRewardEnv(return_reward_idx=0) env = DummyRewardEnv(return_reward_idx=0)
env = NormalizeObservation(env) env = NormalizeObservation(env)
env.reset(return_info=return_info) env.reset()
env.step(env.action_space.sample()) env.step(env.action_space.sample())
assert_almost_equal(env.obs_rms.mean, 0.5, decimal=4) assert_almost_equal(env.obs_rms.mean, 0.5, decimal=4)
env.step(env.action_space.sample()) env.step(env.action_space.sample())
@@ -61,13 +50,7 @@ def test_normalize_observation(return_info: bool):
def test_normalize_reset_info(): def test_normalize_reset_info():
env = DummyRewardEnv(return_reward_idx=0) env = DummyRewardEnv(return_reward_idx=0)
env = NormalizeObservation(env) env = NormalizeObservation(env)
obs = env.reset() obs, info = env.reset()
assert isinstance(obs, np.ndarray)
del obs
obs = env.reset(return_info=False)
assert isinstance(obs, np.ndarray)
del obs
obs, info = env.reset(return_info=True)
assert isinstance(obs, np.ndarray) assert isinstance(obs, np.ndarray)
assert isinstance(info, dict) assert isinstance(info, dict)
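
The tests removed above covered three call forms (`reset()`, `reset(return_info=False)`, `reset(return_info=True)`); after this change only the `(obs, info)` form remains. Downstream code that still encounters environments written against the older convention could normalize the result with a small shim. This is a minimal sketch under stated assumptions: `reset_with_info`, `LegacyEnv`, and `CurrentEnv` are hypothetical names invented for illustration and are not part of gym.

```python
from typing import Any, Dict, Tuple


def reset_with_info(env: Any, **kwargs: Any) -> Tuple[Any, Dict]:
    """Hypothetical shim: always return `(obs, info)` from `env.reset()`."""
    result = env.reset(**kwargs)
    # New-style environments already return a 2-tuple ending in a dict.
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result
    # Old-style environments return a bare observation; pair it with empty info.
    return result, {}


class LegacyEnv:
    """Stand-in for an environment whose reset returns only an observation."""

    def reset(self, seed=None):
        return 0


class CurrentEnv:
    """Stand-in for an environment whose reset returns `(obs, info)`."""

    def reset(self, seed=None, options=None):
        return 0, {"count": 0}
```

The tuple test is a heuristic: an old-style environment whose observation is itself a 2-tuple ending in a dictionary would be misread as `(obs, info)`, so a shim like this is only suitable where observation types are known.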


@@ -57,8 +57,8 @@ def test_initialise_failures(env, message):
env.close() env.close()
def _reset_failure(self, seed=None, return_info=False, options=None): def _reset_failure(self, seed=None, options=None):
return np.array([-1.0], dtype=np.float32) return np.array([-1.0], dtype=np.float32), {}
def _step_failure(self, action): def _step_failure(self, action):


@@ -21,7 +21,7 @@ class FakeEnvironment(gym.Env):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed) super().reset(seed=seed)
observation = self.observation_space.sample() observation = self.observation_space.sample()
return observation return observation, {}
def step(self, action): def step(self, action):
del action del action
@@ -82,9 +82,10 @@ def test_dict_observation(pixels_only):
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
# Check that the added space item is consistent with the added observation. # Check that the added space item is consistent with the added observation.
observation = wrapped_env.reset() observation, info = wrapped_env.reset()
rgb_observation = observation[pixel_key] rgb_observation = observation[pixel_key]
assert isinstance(info, dict)
assert rgb_observation.shape == (height, width, 3) assert rgb_observation.shape == (height, width, 3)
assert rgb_observation.dtype == np.uint8 assert rgb_observation.dtype == np.uint8
@@ -113,9 +114,10 @@ def test_single_array_observation(pixels_only):
pixel_key, pixel_key,
] ]
observation = wrapped_env.reset() observation, info = wrapped_env.reset()
depth_observation = observation[pixel_key] depth_observation = observation[pixel_key]
assert isinstance(info, dict)
assert depth_observation.shape == (32, 32, 3) assert depth_observation.shape == (32, 32, 3)
assert depth_observation.dtype == np.uint8 assert depth_observation.dtype == np.uint8


@@ -31,10 +31,7 @@ def test_record_episode_statistics_reset_info():
env = gym.make("CartPole-v1", disable_env_checker=True) env = gym.make("CartPole-v1", disable_env_checker=True)
env = RecordEpisodeStatistics(env) env = RecordEpisodeStatistics(env)
ob_space = env.observation_space ob_space = env.observation_space
obs = env.reset() obs, info = env.reset()
assert ob_space.contains(obs)
del obs
obs, info = env.reset(return_info=True)
assert ob_space.contains(obs) assert ob_space.contains(obs)
assert isinstance(info, dict) assert isinstance(info, dict)


@@ -23,35 +23,17 @@ def test_record_video_using_default_trigger():
shutil.rmtree("videos") shutil.rmtree("videos")
def test_record_video_reset_return_info(): def test_record_video_reset():
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True) env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
ob_space = env.observation_space ob_space = env.observation_space
obs, info = env.reset(return_info=True) obs, info = env.reset()
env.close() env.close()
assert os.path.isdir("videos") assert os.path.isdir("videos")
shutil.rmtree("videos") shutil.rmtree("videos")
assert ob_space.contains(obs) assert ob_space.contains(obs)
assert isinstance(info, dict) assert isinstance(info, dict)
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
ob_space = env.observation_space
obs = env.reset(return_info=False)
env.close()
assert os.path.isdir("videos")
shutil.rmtree("videos")
assert ob_space.contains(obs)
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
ob_space = env.observation_space
obs = env.reset()
env.close()
assert os.path.isdir("videos")
shutil.rmtree("videos")
assert ob_space.contains(obs)
def test_record_video_step_trigger(): def test_record_video_step_trigger():
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True) env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)


@@ -18,8 +18,8 @@ def test_rescale_action():
seed = 0 seed = 0
obs = env.reset(seed=seed) obs, info = env.reset(seed=seed)
wrapped_obs = wrapped_env.reset(seed=seed) wrapped_obs, wrapped_obs_info = wrapped_env.reset(seed=seed)
assert np.allclose(obs, wrapped_obs) assert np.allclose(obs, wrapped_obs)
obs, reward, _, _ = env.step([1.5]) obs, reward, _, _ = env.step([1.5])


@@ -13,7 +13,7 @@ def test_resize_observation(env_id, shape):
assert isinstance(env.observation_space, spaces.Box) assert isinstance(env.observation_space, spaces.Box)
assert env.observation_space.shape[-1] == 3 assert env.observation_space.shape[-1] == 3
obs = env.reset() obs, _ = env.reset()
if isinstance(shape, int): if isinstance(shape, int):
assert env.observation_space.shape[:2] == (shape, shape) assert env.observation_space.shape[:2] == (shape, shape)
assert obs.shape == (shape, shape, 3) assert obs.shape == (shape, shape, 3)


@@ -14,8 +14,8 @@ def test_time_aware_observation(env_id):
assert isinstance(wrapped_env.observation_space, spaces.Box) assert isinstance(wrapped_env.observation_space, spaces.Box)
assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1 assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1
obs = env.reset() obs, info = env.reset()
wrapped_obs = wrapped_env.reset() wrapped_obs, wrapped_obs_info = wrapped_env.reset()
assert wrapped_env.t == 0.0 assert wrapped_env.t == 0.0
assert wrapped_obs[-1] == 0.0 assert wrapped_obs[-1] == 0.0
assert wrapped_obs.shape[0] == obs.shape[0] + 1 assert wrapped_obs.shape[0] == obs.shape[0] + 1
@@ -30,7 +30,7 @@ def test_time_aware_observation(env_id):
assert wrapped_obs[-1] == 2.0 assert wrapped_obs[-1] == 2.0
assert wrapped_obs.shape[0] == obs.shape[0] + 1 assert wrapped_obs.shape[0] == obs.shape[0] + 1
wrapped_obs = wrapped_env.reset() wrapped_obs, wrapped_obs_info = wrapped_env.reset()
assert wrapped_env.t == 0.0 assert wrapped_env.t == 0.0
assert wrapped_obs[-1] == 0.0 assert wrapped_obs[-1] == 0.0
assert wrapped_obs.shape[0] == obs.shape[0] + 1 assert wrapped_obs.shape[0] == obs.shape[0] + 1


@@ -9,13 +9,7 @@ def test_time_limit_reset_info():
env = gym.make("CartPole-v1", disable_env_checker=True) env = gym.make("CartPole-v1", disable_env_checker=True)
env = TimeLimit(env) env = TimeLimit(env)
ob_space = env.observation_space ob_space = env.observation_space
obs = env.reset() obs, info = env.reset()
assert ob_space.contains(obs)
del obs
obs = env.reset(return_info=False)
assert ob_space.contains(obs)
del obs
obs, info = env.reset(return_info=True)
assert ob_space.contains(obs) assert ob_space.contains(obs)
assert isinstance(info, dict) assert isinstance(info, dict)


@@ -15,9 +15,10 @@ def test_transform_observation(env_id):
gym.make(env_id, disable_env_checker=True), lambda obs: affine_transform(obs) gym.make(env_id, disable_env_checker=True), lambda obs: affine_transform(obs)
) )
obs = env.reset(seed=0) obs, info = env.reset(seed=0)
wrapped_obs = wrapped_env.reset(seed=0) wrapped_obs, wrapped_obs_info = wrapped_env.reset(seed=0)
assert np.allclose(wrapped_obs, affine_transform(obs)) assert np.allclose(wrapped_obs, affine_transform(obs))
assert isinstance(wrapped_obs_info, dict)
action = env.action_space.sample() action = env.action_space.sample()
obs, reward, done, _ = env.step(action) obs, reward, done, _ = env.step(action)


@@ -23,7 +23,7 @@ def test_info_to_list():
env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True) env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)
wrapped_env = VectorListInfo(env_to_wrap) wrapped_env = VectorListInfo(env_to_wrap)
wrapped_env.action_space.seed(SEED) wrapped_env.action_space.seed(SEED)
_, info = wrapped_env.reset(seed=SEED, return_info=True) _, info = wrapped_env.reset(seed=SEED)
assert isinstance(info, list) assert isinstance(info, list)
assert len(info) == NUM_ENVS assert len(info) == NUM_ENVS
@@ -40,7 +40,7 @@ def test_info_to_list():
def test_info_to_list_statistics(): def test_info_to_list_statistics():
env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True) env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)
wrapped_env = VectorListInfo(RecordEpisodeStatistics(env_to_wrap)) wrapped_env = VectorListInfo(RecordEpisodeStatistics(env_to_wrap))
_, info = wrapped_env.reset(seed=SEED, return_info=True) _, info = wrapped_env.reset(seed=SEED)
wrapped_env.action_space.seed(SEED) wrapped_env.action_space.seed(SEED)
assert isinstance(info, list) assert isinstance(info, list)
assert len(info) == NUM_ENVS assert len(info) == NUM_ENVS