Removing return_info argument to env.reset() and deprecated env.seed() function (reset now always returns info) (#2962)

* removed return_info, made info dict mandatory in reset

* tentatively removed deprecated seed api for environments

* added more info type checks to wrapper tests

* formatting/style compliance

* addressed some comments

* polish to address review

* fixed tests after merge, and added a test of the return_info deprecation assertion if found in reset signature

* some organization of env_checker tests, reverted a probable merge error

* added deprecation check for seed function in env

* updated docstring

* removed debug prints, tweaked test_check_seed_deprecation

* changed return_info deprecation check from assertion to warning

* fixes to vector envs, which should now be correctly structured

* added some explanation and typehints for mockup deprecated return info reset function

* re-removed seed function from vector envs

* added explanation to _reset_return_info_type and changed the return statement
This commit is contained in:
John Balis
2022-08-23 11:09:54 -04:00
committed by GitHub
parent 1f864789fd
commit 3a8daafce1
56 changed files with 327 additions and 639 deletions

View File

@@ -23,14 +23,14 @@ The Gym API's API models environments as simple Python `env` classes. Creating e
```python
import gym
env = gym.make("CartPole-v1")
observation, info = env.reset(seed=42, return_info=True)
observation, info = env.reset(seed=42)
for _ in range(1000):
action = env.action_space.sample()
observation, reward, done, info = env.step(action)
if done:
observation, info = env.reset(return_info=True)
observation, info = env.reset()
env.close()
```

View File

@@ -41,11 +41,10 @@ class Env(Generic[ObsType, ActType]):
The main API methods that users of this class need to know are:
- :meth:`step` - Takes a step in the environment using an action returning the next observation, reward,
if the environment terminated and more information.
- :meth:`reset` - Resets the environment to an initial state, returning the initial observation.
if the environment terminated and observation information.
- :meth:`reset` - Resets the environment to an initial state, returning the initial observation and observation information.
- :meth:`render` - Renders the environment observation with modes depending on the output
- :meth:`close` - Closes the environment, important for rendering where pygame is imported
- :meth:`seed` - Seeds the environment's random number generator, :deprecated: in favor of `Env.reset(seed=seed)`.
And set the following attributes:
@@ -124,9 +123,8 @@ class Env(Generic[ObsType, ActType]):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
) -> Union[ObsType, Tuple[ObsType, dict]]:
) -> Tuple[ObsType, dict]:
"""Resets the environment to an initial state and returns the initial observation.
This method can reset the environment's random number generator(s) if ``seed`` is an integer or
@@ -143,8 +141,6 @@ class Env(Generic[ObsType, ActType]):
If you pass an integer, the PRNG will be reset even if it already exists.
Usually, you want to pass an integer *right after the environment has been initialized and then never again*.
Please refer to the minimal example above to see this paradigm in action.
return_info (bool): If true, return additional information along with initial observation.
This info should be analogous to the info returned in :meth:`step`
options (optional dict): Additional information to specify how the environment is reset (optional,
depending on the specific environment)
@@ -152,8 +148,7 @@ class Env(Generic[ObsType, ActType]):
Returns:
observation (object): Observation of the initial state. This will be an element of :attr:`observation_space`
(typically a numpy array) and is analogous to the observation returned by :meth:`step`.
info (optional dictionary): This will *only* be returned if ``return_info=True`` is passed.
It contains auxiliary information complementing ``observation``. This dictionary should be analogous to
info (dictionary): This dictionary contains auxiliary information complementing ``observation``. It should be analogous to
the ``info`` returned by :meth:`step`.
"""
# Initialize the RNG if the seed is manually passed
@@ -193,33 +188,6 @@ class Env(Generic[ObsType, ActType]):
"""
pass
def seed(self, seed=None):
""":deprecated: function that sets the seed for the environment's random number generator(s).
Use `env.reset(seed=seed)` as the new API for setting the seed of the environment.
Note:
Some environments use multiple pseudorandom number generators.
We want to capture all such seeds used in order to ensure that
there aren't accidental correlations between multiple generators.
Args:
seed(Optional int): The seed value for the random number generator
Returns:
seeds (List[int]): Returns the list of seeds used in this environment's random
number generators. The first value in the list should be the
"main" seed, or the value which a reproducer should pass to
'seed'. Often, the main seed equals the provided 'seed', but
this won't be true `if seed=None`, for example.
"""
deprecation(
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
"Please use `env.reset(seed=seed)` instead."
)
self._np_random, seed = seeding.np_random(seed)
return [seed]
@property
def unwrapped(self) -> "Env":
"""Returns the base non-wrapped environment.
@@ -370,7 +338,7 @@ class Wrapper(Env[ObsType, ActType]):
return step_api_compatibility(self.env.step(action), self.new_step_api)
def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]:
def reset(self, **kwargs) -> Tuple[ObsType, dict]:
"""Resets the environment with kwargs."""
return self.env.reset(**kwargs)
@@ -384,10 +352,6 @@ class Wrapper(Env[ObsType, ActType]):
"""Closes the environment."""
return self.env.close()
def seed(self, seed=None):
"""Seeds the environment."""
return self.env.seed(seed)
def __str__(self):
"""Returns the wrapper name and the unwrapped environment string."""
return f"<{type(self).__name__}{self.env}>"
@@ -432,11 +396,8 @@ class ObservationWrapper(Wrapper):
def reset(self, **kwargs):
"""Resets the environment, returning a modified observation using :meth:`self.observation`."""
if kwargs.get("return_info", False):
obs, info = self.env.reset(**kwargs)
return self.observation(obs), info
else:
return self.observation(self.env.reset(**kwargs))
obs, info = self.env.reset(**kwargs)
return self.observation(obs), info
def step(self, action):
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""

View File

@@ -428,7 +428,6 @@ class BipedalWalker(gym.Env, EzPickle):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
@@ -514,10 +513,7 @@ class BipedalWalker(gym.Env, EzPickle):
self.lidar = [LidarCallback() for _ in range(10)]
self.renderer.reset()
if not return_info:
return self.step(np.array([0, 0, 0, 0]))[0]
else:
return self.step(np.array([0, 0, 0, 0]))[0], {}
return self.step(np.array([0, 0, 0, 0]))[0], {}
def step(self, action: np.ndarray):
assert self.hull is not None

View File

@@ -483,7 +483,6 @@ class CarRacing(gym.Env, EzPickle):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
@@ -519,10 +518,7 @@ class CarRacing(gym.Env, EzPickle):
self.car = Car(self.world, *self.track[0][1:4])
self.renderer.reset()
if not return_info:
return self.step(None)[0]
else:
return self.step(None)[0], {}
return self.step(None)[0], {}
def step(self, action: Union[np.ndarray, int]):
assert self.car is not None

View File

@@ -305,7 +305,6 @@ class LunarLander(gym.Env, EzPickle):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
@@ -413,10 +412,7 @@ class LunarLander(gym.Env, EzPickle):
self.drawlist = [self.lander] + self.legs
self.renderer.reset()
if not return_info:
return self.step(np.array([0, 0]) if self.continuous else 0)[0]
else:
return self.step(np.array([0, 0]) if self.continuous else 0)[0], {}
return self.step(np.array([0, 0]) if self.continuous else 0)[0], {}
def _create_particle(self, mass, x, y, ttl):
p = self.world.CreateDynamicBody(
@@ -774,7 +770,7 @@ def demo_heuristic_lander(env, seed=None, render=False):
total_reward = 0
steps = 0
s = env.reset(seed=seed)
s, info = env.reset(seed=seed)
while True:
a = heuristic(env, s)
s, r, terminated, truncated, info = step_api_compatibility(env.step(a), True)

View File

@@ -180,13 +180,7 @@ class AcrobotEnv(core.Env):
self.action_space = spaces.Discrete(3)
self.state = None
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
# Note that if you use custom reset bounds, it may lead to out-of-bound
# state/observations.
@@ -199,10 +193,7 @@ class AcrobotEnv(core.Env):
self.renderer.reset()
self.renderer.render_step()
if not return_info:
return self._get_ob()
else:
return self._get_ob(), {}
return self._get_ob(), {}
def step(self, a):
s = self.state

View File

@@ -192,7 +192,6 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
@@ -205,10 +204,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
self.steps_beyond_terminated = None
self.renderer.reset()
self.renderer.render_step()
if not return_info:
return np.array(self.state, dtype=np.float32)
else:
return np.array(self.state, dtype=np.float32), {}
return np.array(self.state, dtype=np.float32), {}
def render(self):
return self.renderer.get_renders()

View File

@@ -174,13 +174,7 @@ class Continuous_MountainCarEnv(gym.Env):
self.renderer.render_step()
return self.state, reward, terminated, False, {}
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
# Note that if you use custom reset bounds, it may lead to out-of-bound
# state/observations.
@@ -188,10 +182,7 @@ class Continuous_MountainCarEnv(gym.Env):
self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
self.renderer.reset()
self.renderer.render_step()
if not return_info:
return np.array(self.state, dtype=np.float32)
else:
return np.array(self.state, dtype=np.float32), {}
return np.array(self.state, dtype=np.float32), {}
def _height(self, xs):
return np.sin(3 * xs) * 0.45 + 0.55

View File

@@ -152,7 +152,6 @@ class MountainCarEnv(gym.Env):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
@@ -162,10 +161,7 @@ class MountainCarEnv(gym.Env):
self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
self.renderer.reset()
self.renderer.render_step()
if not return_info:
return np.array(self.state, dtype=np.float32)
else:
return np.array(self.state, dtype=np.float32), {}
return np.array(self.state, dtype=np.float32), {}
def _height(self, xs):
return np.sin(3 * xs) * 0.45 + 0.55

View File

@@ -138,13 +138,7 @@ class PendulumEnv(gym.Env):
self.renderer.render_step()
return self._get_obs(), -costs, False, False, {}
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
if options is None:
high = np.array([DEFAULT_X, DEFAULT_Y])
@@ -162,10 +156,7 @@ class PendulumEnv(gym.Env):
self.renderer.reset()
self.renderer.render_step()
if not return_info:
return self._get_obs()
else:
return self._get_obs(), {}
return self._get_obs(), {}
def _get_obs(self):
theta, thetadot = self.state

View File

@@ -142,7 +142,6 @@ class BaseMujocoEnv(gym.Env):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
@@ -152,10 +151,7 @@ class BaseMujocoEnv(gym.Env):
ob = self.reset_model()
self.renderer.reset()
self.renderer.render_step()
if not return_info:
return ob
else:
return ob, {}
return ob, {}
def set_state(self, qpos, qvel):
"""

View File

@@ -167,7 +167,6 @@ class BlackjackEnv(gym.Env):
def reset(
self,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
@@ -189,10 +188,7 @@ class BlackjackEnv(gym.Env):
self.renderer.reset()
self.renderer.render_step()
if not return_info:
return self._get_obs()
else:
return self._get_obs(), {}
return self._get_obs(), {}
def render(self):
return self.renderer.get_renders()

View File

@@ -149,22 +149,14 @@ class CliffWalkingEnv(Env):
self.renderer.render_step()
return (int(s), r, t, False, {"prob": p})
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
self.lastaction = None
self.renderer.reset()
self.renderer.render_step()
if not return_info:
return int(self.s)
else:
return int(self.s), {"prob": 1}
return int(self.s), {"prob": 1}
def render(self):
return self.renderer.get_renders()

View File

@@ -256,7 +256,6 @@ class FrozenLakeEnv(Env):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
@@ -266,10 +265,7 @@ class FrozenLakeEnv(Env):
self.renderer.reset()
self.renderer.render_step()
if not return_info:
return int(self.s)
else:
return int(self.s), {"prob": 1}
return int(self.s), {"prob": 1}
def render(self):
return self.renderer.get_renders()

View File

@@ -89,7 +89,7 @@ class TaxiEnv(Env):
### Info
``step`` and ``reset(return_info=True)`` will return an info dictionary that contains "p" and "action_mask" containing
``step`` and ``reset()`` will return an info dictionary that contains "prob" and "action_mask" containing
the probability that the state is taken and a mask of what actions will result in a change of state to speed up training.
As Taxi's initial state is stochastic, the "prob" key represents the probability of the
@@ -266,7 +266,6 @@ class TaxiEnv(Env):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
@@ -275,10 +274,8 @@ class TaxiEnv(Env):
self.taxi_orientation = 0
self.renderer.reset()
self.renderer.render_step()
if not return_info:
return int(self.s)
else:
return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)}
return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)}
def render(self):
return self.renderer.get_renders()

View File

@@ -73,7 +73,7 @@ def check_reset_seed(env: gym.Env):
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
):
try:
obs_1 = env.reset(seed=123)
obs_1, info = env.reset(seed=123)
assert (
obs_1 in env.observation_space
), "The observation returned by `env.reset(seed=123)` is not within the observation space."
@@ -85,7 +85,7 @@ def check_reset_seed(env: gym.Env):
env.unwrapped._np_random # pyright: ignore [reportPrivateUsage]
)
obs_2 = env.reset(seed=123)
obs_2, info = env.reset(seed=123)
assert (
obs_2 in env.observation_space
), "The observation returned by `env.reset(seed=123)` is not within the observation space."
@@ -98,7 +98,7 @@ def check_reset_seed(env: gym.Env):
== seed_123_rng.bit_generator.state
), "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`."
obs_3 = env.reset(seed=456)
obs_3, info = env.reset(seed=456)
assert (
obs_3 in env.observation_space
), "The observation returned by `env.reset(seed=456)` is not within the observation space."
@@ -126,53 +126,6 @@ def check_reset_seed(env: gym.Env):
)
def check_reset_info(env: gym.Env):
"""Checks that :meth:`reset` supports the ``return_info`` keyword.
Args:
env: The environment to check
Raises:
AssertionError: The environment cannot be reset with `return_info=True`,
even though `return_info` or `kwargs` appear in the signature.
"""
signature = inspect.signature(env.reset)
if "return_info" in signature.parameters or (
"kwargs" in signature.parameters
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
):
try:
obs = env.reset(return_info=False)
assert (
obs in env.observation_space
), "The value returned by `env.reset(return_info=True)` is not within the observation space."
result = env.reset(return_info=True)
assert isinstance(
result, tuple
), f"Calling the reset method with `return_info=True` did not return a tuple, actual type: {type(result)}"
assert (
len(result) == 2
), f"Calling the reset method with `return_info=True` did not return a 2-tuple, actual length: {len(result)}"
obs, info = result
assert (
obs in env.observation_space
), "The first element returned by `env.reset(return_info=True)` is not within the observation space."
assert isinstance(
info, dict
), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: {type(info)}"
except TypeError as e:
raise AssertionError(
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` appear in the signature. "
f"This should never happen, please report this issue. The error was: {e}"
)
else:
raise gym.error.Error(
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument."
)
def check_reset_options(env: gym.Env):
"""Check that the environment can be reset with options.
@@ -201,6 +154,64 @@ def check_reset_options(env: gym.Env):
)
def check_reset_return_info_deprecation(env: gym.Env):
    """Warns when ``env.reset`` still declares the deprecated ``return_info`` parameter.

    Only the signature of ``env.reset`` is inspected; the method itself is not called.

    Args:
        env: The environment to check

    Note:
        This function does not raise. The deprecation is reported through ``logger.warn``
        (presumably surfaced as a ``UserWarning`` - confirm against the logger implementation).
    """
    signature = inspect.signature(env.reset)
    if "return_info" in signature.parameters:
        # Adjacent string literals are joined verbatim, so every fragment that is
        # followed by another one must end with its own separating space.
        logger.warn(
            "`return_info` is deprecated as an optional argument to `reset`. `reset` "
            "should now always return `obs, info` where `obs` is an observation, and `info` is a dictionary "
            "containing additional information."
        )
def check_seed_deprecation(env: gym.Env):
    """Warns when the environment still exposes a callable ``seed`` attribute.

    Environments without a ``seed`` attribute, or whose ``seed`` attribute is not
    callable, pass silently.

    Args:
        env: The environment to check

    Note:
        The deprecation is reported through ``logger.warn``; nothing is raised here.
    """
    # A missing attribute is treated the same as a non-callable one.
    if not callable(getattr(env, "seed", None)):
        return

    logger.warn(
        "Official support for the `seed` function is dropped. "
        "Standard practice is to reset gym environments using `env.reset(seed=<desired seed>)`"
    )
def check_reset_return_type(env: gym.Env):
    """Verifies that ``env.reset()`` returns a 2-tuple ``(obs, info)``.

    The environment is reset once with no arguments. The returned value must be a
    tuple of length two whose first element lies in ``env.observation_space`` and
    whose second element is a ``dict``.

    Args:
        env: The environment to check

    Raises:
        AssertionError: If any of the conditions above is violated
    """
    reset_output = env.reset()

    # The container shape is validated first so the unpacking below is safe.
    assert isinstance(
        reset_output, tuple
    ), f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(reset_output)}`"
    assert (
        len(reset_output) == 2
    ), f"Calling the reset method did not return a 2-tuple, actual length: {len(reset_output)}"

    observation, reset_info = reset_output
    assert (
        observation in env.observation_space
    ), "The first element returned by `env.reset()` is not within the observation space."
    assert isinstance(
        reset_info, dict
    ), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(reset_info)}"
def check_space_limit(space, space_type: str):
"""Check the space limit for only the Box space as a test that only runs as part of `check_env`."""
if isinstance(space, spaces.Box):
@@ -279,9 +290,11 @@ def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = False):
check_space_limit(env.observation_space, "observation")
# ==== Check the reset method ====
check_seed_deprecation(env)
check_reset_return_info_deprecation(env)
check_reset_return_type(env)
check_reset_seed(env)
check_reset_options(env)
check_reset_info(env)
# ============ Check the returned values ===============
env_reset_passive_checker(env)

View File

@@ -183,14 +183,6 @@ def env_reset_passive_checker(env, **kwargs):
f"Actual default: {seed_param}"
)
if "return_info" not in signature.parameters and not (
"kwargs" in signature.parameters
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
):
logger.warn(
"Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting."
)
if "options" not in signature.parameters and "kwargs" not in signature.parameters:
logger.warn(
"Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information."
@@ -198,21 +190,17 @@ def env_reset_passive_checker(env, **kwargs):
# Checks the result of env.reset with kwargs
result = env.reset(**kwargs)
if kwargs.get("return_info", False) is True:
assert isinstance(
result, tuple
), f"The result returned by `env.reset(return_info=True)` was not a tuple, actual type: {type(result)}"
assert (
len(result) == 2
), f"The length of the result returned by `env.reset(return_info=True)` is not 2, actual length: {len(result)}"
obs, info = result
assert isinstance(
info, dict
), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: {type(info)}"
else:
obs = result
if not isinstance(result, tuple):
logger.warn(
f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(result)}`"
)
obs, info = result
check_obs(obs, env.observation_space, "reset")
assert isinstance(
info, dict
), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}"
return result

View File

@@ -171,38 +171,9 @@ class AsyncVectorEnv(VectorEnv):
self._state = AsyncState.DEFAULT
self._check_spaces()
def seed(self, seed=None):
"""Seeds the vector environments.
Args:
seed: The seeds use with the environments
Raises:
AlreadyPendingCallError: Calling `seed` while waiting for a pending call to complete
"""
super().seed(seed=seed)
self._assert_is_running()
if seed is None:
seed = [None for _ in range(self.num_envs)]
if isinstance(seed, int):
seed = [seed + i for i in range(self.num_envs)]
assert len(seed) == self.num_envs
if self._state != AsyncState.DEFAULT:
raise AlreadyPendingCallError(
f"Calling `seed` while waiting for a pending call to `{self._state.value}` to complete.",
self._state.value,
)
for pipe, seed in zip(self.parent_pipes, seed):
pipe.send(("seed", seed))
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
def reset_async(
self,
seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
"""Send calls to the :obj:`reset` methods of the sub-environments.
@@ -211,7 +182,6 @@ class AsyncVectorEnv(VectorEnv):
Args:
seed: List of seeds for each environment
return_info: If to return information
options: The reset option
Raises:
@@ -238,8 +208,6 @@ class AsyncVectorEnv(VectorEnv):
single_kwargs = {}
if single_seed is not None:
single_kwargs["seed"] = single_seed
if return_info:
single_kwargs["return_info"] = return_info
if options is not None:
single_kwargs["options"] = options
@@ -250,7 +218,6 @@ class AsyncVectorEnv(VectorEnv):
self,
timeout: Optional[Union[int, float]] = None,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
) -> Union[ObsType, Tuple[ObsType, List[dict]]]:
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
@@ -258,7 +225,6 @@ class AsyncVectorEnv(VectorEnv):
Args:
timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out.
seed: ignored
return_info: If to return information
options: ignored
Returns:
@@ -286,27 +252,17 @@ class AsyncVectorEnv(VectorEnv):
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
if return_info:
infos = {}
results, info_data = zip(*results)
for i, info in enumerate(info_data):
infos = self._add_info(infos, info, i)
infos = {}
results, info_data = zip(*results)
for i, info in enumerate(info_data):
infos = self._add_info(infos, info, i)
if not self.shared_memory:
self.observations = concatenate(
self.single_observation_space, results, self.observations
)
if not self.shared_memory:
self.observations = concatenate(
self.single_observation_space, results, self.observations
)
return (
deepcopy(self.observations) if self.copy else self.observations
), infos
else:
if not self.shared_memory:
self.observations = concatenate(
self.single_observation_space, results, self.observations
)
return deepcopy(self.observations) if self.copy else self.observations
return (deepcopy(self.observations) if self.copy else self.observations), infos
def step_async(self, actions: np.ndarray):
"""Send the calls to :obj:`step` to each sub-environment.
@@ -606,12 +562,8 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
while True:
command, data = pipe.recv()
if command == "reset":
if "return_info" in data and data["return_info"] is True:
observation, info = env.reset(**data)
pipe.send(((observation, info), True))
else:
observation = env.reset(**data)
pipe.send((observation, True))
observation, info = env.reset(**data)
pipe.send(((observation, info), True))
elif command == "step":
(
@@ -622,8 +574,9 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
info,
) = step_api_compatibility(env.step(data), True)
if terminated or truncated:
info["final_observation"] = observation
observation = env.reset()
old_observation = observation
observation, info = env.reset()
info["final_observation"] = old_observation
pipe.send(((observation, reward, terminated, truncated, info), True))
elif command == "seed":
env.seed(data)
@@ -676,18 +629,12 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
while True:
command, data = pipe.recv()
if command == "reset":
if "return_info" in data and data["return_info"] is True:
observation, info = env.reset(**data)
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
pipe.send(((None, info), True))
else:
observation = env.reset(**data)
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
pipe.send((None, True))
observation, info = env.reset(**data)
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
pipe.send(((None, info), True))
elif command == "step":
(
observation,
@@ -697,8 +644,10 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
info,
) = step_api_compatibility(env.step(data), True)
if terminated or truncated:
info["final_observation"] = observation
observation = env.reset()
old_observation = observation
observation, info = env.reset()
info["final_observation"] = old_observation
write_to_shared_memory(
observation_space, index, observation, shared_memory
)

View File

@@ -93,14 +93,12 @@ class SyncVectorEnv(VectorEnv):
def reset_wait(
self,
seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
Args:
seed: The reset environment seed
return_info: If to return information
options: Option information for the environment reset
Returns:
@@ -123,26 +121,15 @@ class SyncVectorEnv(VectorEnv):
kwargs["seed"] = single_seed
if options is not None:
kwargs["options"] = options
if return_info is True:
kwargs["return_info"] = return_info
if not return_info:
observation = env.reset(**kwargs)
observations.append(observation)
else:
observation, info = env.reset(**kwargs)
observations.append(observation)
infos = self._add_info(infos, info, i)
observation, info = env.reset(**kwargs)
observations.append(observation)
infos = self._add_info(infos, info, i)
self.observations = concatenate(
self.single_observation_space, observations, self.observations
)
if not return_info:
return deepcopy(self.observations) if self.copy else self.observations
else:
return (
deepcopy(self.observations) if self.copy else self.observations
), infos
return (deepcopy(self.observations) if self.copy else self.observations), infos
def step_async(self, actions):
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
@@ -164,8 +151,9 @@ class SyncVectorEnv(VectorEnv):
info,
) = step_api_compatibility(env.step(action), True)
if self._terminateds[i] or self._truncateds[i]:
info["final_observation"] = observation
observation = env.reset()
old_observation = observation
observation, info = env.reset()
info["final_observation"] = old_observation
observations.append(observation)
infos = self._add_info(infos, info, i)
self.observations = concatenate(

View File

@@ -60,7 +60,6 @@ class VectorEnv(gym.Env):
def reset_async(
self,
seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
"""Reset the sub-environments asynchronously.
@@ -70,7 +69,6 @@ class VectorEnv(gym.Env):
Args:
seed: The reset seed
return_info: If to return info
options: Reset options
"""
pass
@@ -78,7 +76,6 @@ class VectorEnv(gym.Env):
def reset_wait(
self,
seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
"""Retrieves the results of a :meth:`reset_async` call.
@@ -87,7 +84,6 @@ class VectorEnv(gym.Env):
Args:
seed: The reset seed
return_info: If to return info
options: Reset options
Returns:
@@ -102,21 +98,19 @@ class VectorEnv(gym.Env):
self,
*,
seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
"""Reset all parallel environments and return a batch of initial observations.
Args:
seed: The environment reset seeds
return_info: If to return the info
options: The reset options
Returns:
A batch of observations from the vectorized environment.
"""
self.reset_async(seed=seed, return_info=return_info, options=options)
return self.reset_wait(seed=seed, return_info=return_info, options=options)
self.reset_async(seed=seed, options=options)
return self.reset_wait(seed=seed, options=options)
def step_async(self, actions):
"""Asynchronously performs steps in the sub-environments.
@@ -220,21 +214,6 @@ class VectorEnv(gym.Env):
self.close_extras(**kwargs)
self.closed = True
def seed(self, seed=None):
"""Set the random seed in all parallel environments.
Args:
seed: Random seed for each parallel environment. If ``seed`` is a list of
length ``num_envs``, then the items of the list are chosen as random
seeds. If ``seed`` is an int, then each parallel environment uses the random
seed ``seed + n``, where ``n`` is the index of the parallel environment
(between ``0`` and ``num_envs - 1``).
"""
deprecation(
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
"Please use `env.reset(seed=seed) instead in VectorEnvs."
)
def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
"""Add env info to the info dictionary of the vectorized environment.
@@ -339,9 +318,6 @@ class VectorEnvWrapper(VectorEnv):
def close_extras(self, **kwargs):
return self.env.close_extras(**kwargs)
def seed(self, seed=None):
return self.env.seed(seed)
def call(self, name, *args, **kwargs):
return self.env.call(name, *args, **kwargs)

View File

@@ -151,11 +151,7 @@ class AtariPreprocessing(gym.Wrapper):
def reset(self, **kwargs):
"""Resets the environment using preprocessing."""
# NoopReset
if kwargs.get("return_info", False):
_, reset_info = self.env.reset(**kwargs)
else:
_ = self.env.reset(**kwargs)
reset_info = {}
_, reset_info = self.env.reset(**kwargs)
noops = (
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
@@ -168,11 +164,7 @@ class AtariPreprocessing(gym.Wrapper):
)
reset_info.update(step_info)
if terminated or truncated:
if kwargs.get("return_info", False):
_, reset_info = self.env.reset(**kwargs)
else:
_ = self.env.reset(**kwargs)
reset_info = {}
_, reset_info = self.env.reset(**kwargs)
self.lives = self.ale.lives()
if self.grayscale_obs:
@@ -181,10 +173,7 @@ class AtariPreprocessing(gym.Wrapper):
self.ale.getScreenRGB(self.obs_buffer[0])
self.obs_buffer[1].fill(0)
if kwargs.get("return_info", False):
return self._get_obs(), reset_info
else:
return self._get_obs()
return self._get_obs(), reset_info
def _get_obs(self):
if self.frame_skip > 1: # more efficient in-place pooling

View File

@@ -48,7 +48,7 @@ class AutoResetWrapper(gym.Wrapper):
if terminated or truncated:
new_obs, new_info = self.env.reset(return_info=True)
new_obs, new_info = self.env.reset()
assert (
"final_observation" not in new_info
), 'info dict cannot contain key "final_observation" '

View File

@@ -191,14 +191,8 @@ class FrameStack(gym.ObservationWrapper):
Returns:
The stacked observations
"""
if kwargs.get("return_info", False):
obs, info = self.env.reset(**kwargs)
else:
obs = self.env.reset(**kwargs)
info = None # Unused
obs, info = self.env.reset(**kwargs)
[self.frames.append(obs) for _ in range(self.num_stack)]
if kwargs.get("return_info", False):
return self.observation(None), info
else:
return self.observation(None)
return self.observation(None), info

View File

@@ -89,20 +89,12 @@ class NormalizeObservation(gym.core.Wrapper):
def reset(self, **kwargs):
"""Resets the environment and normalizes the observation."""
if kwargs.get("return_info", False):
obs, info = self.env.reset(**kwargs)
obs, info = self.env.reset(**kwargs)
if self.is_vector_env:
return self.normalize(obs), info
else:
return self.normalize(np.array([obs]))[0], info
if self.is_vector_env:
return self.normalize(obs), info
else:
obs = self.env.reset(**kwargs)
if self.is_vector_env:
return self.normalize(obs)
else:
return self.normalize(np.array([obs]))[0]
return self.normalize(np.array([obs]))[0], info
def normalize(self, obs):
"""Normalises the observation using the running mean and variance of the observations."""

View File

@@ -57,9 +57,6 @@ class VectorListInfo(gym.Wrapper):
def reset(self, **kwargs):
"""Resets the environment using kwargs."""
if not kwargs.get("return_info"):
return self.env.reset(**kwargs)
obs, infos = self.env.reset(**kwargs)
list_info = self._convert_info_to_list(infos)
return obs, list_info

View File

@@ -139,7 +139,7 @@ def test_taxi_action_mask():
def test_taxi_encode_decode():
env = TaxiEnv()
state = env.reset()
state, info = env.reset()
for _ in range(100):
assert (
env.encode(*env.decode(state)) == state

View File

@@ -82,8 +82,8 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
env_1 = env_spec.make(disable_env_checker=True)
env_2 = env_spec.make(disable_env_checker=True)
initial_obs_1 = env_1.reset(seed=SEED)
initial_obs_2 = env_2.reset(seed=SEED)
initial_obs_1, initial_info_1 = env_1.reset(seed=SEED)
initial_obs_2, initial_info_2 = env_2.reset(seed=SEED)
assert_equals(initial_obs_1, initial_obs_2)
env_1.action_space.seed(SEED)

View File

@@ -17,8 +17,8 @@ def verify_environments_match(
old_env = envs.make(old_env_id, disable_env_checker=True)
new_env = envs.make(new_env_id, disable_env_checker=True)
old_reset_obs = old_env.reset(seed=seed)
new_reset_obs = new_env.reset(seed=seed)
old_reset_obs, old_info = old_env.reset(seed=seed)
new_reset_obs, new_info = new_env.reset(seed=seed)
np.testing.assert_allclose(old_reset_obs, new_reset_obs)
@@ -56,7 +56,7 @@ EXCLUDE_POS_FROM_OBS = [
def test_obs_space_mujoco_environments(env_spec: EnvSpec):
"""Check that the returned observations are contained in the observation space of the environment"""
env = env_spec.make(disable_env_checker=True)
reset_obs = env.reset()
reset_obs, info = env.reset()
assert env.observation_space.contains(
reset_obs
), f"Obseravtion returned by reset() of {env_spec.id} is not contained in the default observation space {env.observation_space}."
@@ -73,7 +73,7 @@ def test_obs_space_mujoco_environments(env_spec: EnvSpec):
env = env_spec.make(
disable_env_checker=True, exclude_current_positions_from_observation=False
)
reset_obs = env.reset()
reset_obs, info = env.reset()
assert env.observation_space.contains(
reset_obs
), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when excluding current position from observation."
@@ -86,7 +86,7 @@ def test_obs_space_mujoco_environments(env_spec: EnvSpec):
# Ant-v4 has the option of including contact forces in the observation space with the use_contact_forces argument
if env_spec.name == "Ant" and env_spec.version == 4:
env = env_spec.make(disable_env_checker=True, use_contact_forces=True)
reset_obs = env.reset()
reset_obs, info = env.reset()
assert env.observation_space.contains(
reset_obs
), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when using contact forces."

View File

@@ -21,17 +21,9 @@ class UnittestEnv(core.Env):
observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
action_space = spaces.Discrete(3)
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
if return_info:
return self.observation_space.sample(), {"info": "dummy"}
return self.observation_space.sample() # Dummy observation
return self.observation_space.sample(), {"info": "dummy"}
def step(self, action):
observation = self.observation_space.sample() # Dummy observation
@@ -45,22 +37,13 @@ class UnknownSpacesEnv(core.Env):
on external resources), it is not encouraged.
"""
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
self.observation_space = spaces.Box(
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
)
self.action_space = spaces.Discrete(3)
if not return_info:
return self.observation_space.sample() # Dummy observation
else:
return self.observation_space.sample(), {} # Dummy observation with info
return self.observation_space.sample(), {} # Dummy observation with info
def step(self, action):
observation = self.observation_space.sample() # Dummy observation

View File

@@ -12,16 +12,12 @@ def basic_reset_fn(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
) -> Union[ObsType, Tuple[ObsType, dict]]:
"""A basic reset function that will pass the environment check using random actions from the observation space."""
super(GenericTestEnv, self).reset(seed=seed)
self.observation_space.seed(seed)
if return_info:
return self.observation_space.sample(), {"options": options}
else:
return self.observation_space.sample()
return self.observation_space.sample(), {"options": options}
def new_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
@@ -77,7 +73,6 @@ class GenericTestEnv(gym.Env):
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
) -> Union[ObsType, Tuple[ObsType, dict]]:
# If you need a default working reset function, use `basic_reset_fn` above

View File

@@ -1,16 +1,21 @@
"""Tests that the `env_checker` runs as expected and all errors are possible."""
import re
import warnings
from typing import Tuple, Union
import numpy as np
import pytest
import gym
from gym import spaces
from gym.core import ObsType
from gym.utils.env_checker import (
check_env,
check_reset_info,
check_reset_options,
check_reset_return_info_deprecation,
check_reset_return_type,
check_reset_seed,
check_seed_deprecation,
)
from tests.testing_env import GenericTestEnv
@@ -48,29 +53,27 @@ def test_no_error_warnings(env):
assert len(warnings) == 0, [warning.message for warning in warnings]
def _no_super_reset(self, seed=None, return_info=False, options=None):
def _no_super_reset(self, seed=None, options=None):
self.np_random.random() # generates a new prng
# generate seed deterministic result
self.observation_space.seed(0)
return self.observation_space.sample()
return self.observation_space.sample(), {}
def _super_reset_fixed(self, seed=None, return_info=False, options=None):
def _super_reset_fixed(self, seed=None, options=None):
# Call super that ignores the seed passed, use fixed seed
super(GenericTestEnv, self).reset(seed=1)
# deterministic output
self.observation_space._np_random = self.np_random
return self.observation_space.sample()
return self.observation_space.sample(), {}
def _reset_default_seed(
self: GenericTestEnv, seed="Error", return_info=False, options=None
):
def _reset_default_seed(self: GenericTestEnv, seed="Error", options=None):
super(GenericTestEnv, self).reset(seed=seed)
self.observation_space._np_random = ( # pyright: ignore [reportPrivateUsage]
self.np_random
)
return self.observation_space.sample()
return self.observation_space.sample(), {}
@pytest.mark.parametrize(
@@ -78,12 +81,12 @@ def _reset_default_seed(
[
[
gym.error.Error,
lambda self: self.observation_space.sample(),
lambda self: (self.observation_space.sample(), {}),
"The `reset` method does not provide a `seed` or `**kwargs` keyword argument.",
],
[
AssertionError,
lambda self, seed, *_: self.observation_space.sample(),
lambda self, seed, *_: (self.observation_space.sample(), {}),
"Expects the random number generator to have been generated given a seed was passed to reset. Mostly likely the environment reset function does not call `super().reset(seed=seed)`.",
],
[
@@ -115,95 +118,125 @@ def test_check_reset_seed(test, func: callable, message: str):
check_reset_seed(GenericTestEnv(reset_fn=func))
def _deprecated_return_info(
self, return_info: bool = False
) -> Union[Tuple[ObsType, dict], ObsType]:
"""function to simulate the signature and behavior of a `reset` function with the deprecated `return_info` optional argument"""
if return_info:
return self.observation_space.sample(), {}
else:
return self.observation_space.sample()
def _reset_var_keyword_kwargs(self, kwargs):
return self.observation_space.sample()
return self.observation_space.sample(), {}
def _reset_return_info_type(self, seed=None, return_info=False, options=None):
if return_info:
return [1, 2]
else:
return self.observation_space.sample()
def _reset_return_info_type(self, seed=None, options=None):
"""Returns a `list` instead of a `tuple`. This function is used to make sure `env_checker` correctly
checks that the return type of `env.reset()` is a `tuple`"""
return [self.observation_space.sample(), {}]
def _reset_return_info_length(self, seed=None, return_info=False, options=None):
if return_info:
return 1, 2, 3
else:
return self.observation_space.sample()
def _reset_return_info_length(self, seed=None, options=None):
return 1, 2, 3
def _return_info_obs_outside(self, seed=None, return_info=False, options=None):
if return_info:
return self.observation_space.sample() + self.observation_space.high, {}
else:
return self.observation_space.sample()
def _return_info_obs_outside(self, seed=None, options=None):
return self.observation_space.sample() + self.observation_space.high, {}
def _return_info_not_dict(self, seed=None, return_info=False, options=None):
if return_info:
return self.observation_space.sample(), ["key", "value"]
else:
return self.observation_space.sample()
def _return_info_not_dict(self, seed=None, options=None):
return self.observation_space.sample(), ["key", "value"]
@pytest.mark.parametrize(
"test,func,message",
[
[
gym.error.Error,
lambda self, *_: self.observation_space.sample(),
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument.",
],
[
gym.error.Error,
_reset_var_keyword_kwargs,
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument.",
],
[
AssertionError,
_reset_return_info_type,
"Calling the reset method with `return_info=True` did not return a tuple, actual type: <class 'list'>",
"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `<class 'list'>`",
],
[
AssertionError,
_reset_return_info_length,
"Calling the reset method with `return_info=True` did not return a 2-tuple, actual length: 3",
"Calling the reset method did not return a 2-tuple, actual length: 3",
],
[
AssertionError,
_return_info_obs_outside,
"The first element returned by `env.reset(return_info=True)` is not within the observation space.",
"The first element returned by `env.reset()` is not within the observation space.",
],
[
AssertionError,
_return_info_not_dict,
"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: <class 'list'>",
"The second element returned by `env.reset()` was not a dictionary, actual type: <class 'list'>",
],
],
)
def test_check_reset_info(test, func: callable, message: str):
"""Tests the check reset info function works as expected."""
if test is UserWarning:
with pytest.warns(
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
):
check_reset_info(GenericTestEnv(reset_fn=func))
else:
with pytest.raises(test, match=f"^{re.escape(message)}$"):
check_reset_info(GenericTestEnv(reset_fn=func))
def test_check_reset_return_type(test, func: callable, message: str):
"""Tests the check `env.reset()` function has a correct return type."""
with pytest.raises(test, match=f"^{re.escape(message)}$"):
check_reset_return_type(GenericTestEnv(reset_fn=func))
@pytest.mark.parametrize(
"test,func,message",
[
[
UserWarning,
_deprecated_return_info,
"`return_info` is deprecated as an optional argument to `reset`. `reset`"
"should now always return `obs, info` where `obs` is an observation, and `info` is a dictionary"
"containing additional information.",
],
],
)
def test_check_reset_return_info_deprecation(test, func: callable, message: str):
"""Tests that return_info has been correctly deprecated as an argument to `env.reset()`."""
with pytest.warns(test, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"):
check_reset_return_info_deprecation(GenericTestEnv(reset_fn=func))
def test_check_seed_deprecation():
"""Tests that `check_seed_deprecation()` throws a warning if `env.seed()` has not been removed."""
message = """Official support for the `seed` function is dropped. Standard practice is to reset gym environments using `env.reset(seed=<desired seed>)`"""
env = GenericTestEnv()
def seed(seed):
return
with pytest.warns(
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
):
env.seed = seed
assert callable(env.seed)
check_seed_deprecation(env)
with warnings.catch_warnings(record=True) as caught_warnings:
env.seed = []
check_seed_deprecation(env)
env.seed = 123
check_seed_deprecation(env)
del env.seed
check_seed_deprecation(env)
assert len(caught_warnings) == 0
def test_check_reset_options():
"""Tests the check_reset_options function."""
with pytest.raises(
gym.error.Error,
match=re.escape(
"The `reset` method does not provide an `options` or `**kwargs` keyword argument"
),
):
check_reset_options(GenericTestEnv(reset_fn=lambda self: 0))
check_reset_options(GenericTestEnv(reset_fn=lambda self: (0, {})))
@pytest.mark.parametrize(

View File

@@ -242,24 +242,20 @@ def test_check_obs(test, obs, obs_space: spaces.Space, message: str):
assert len(warnings) == 0
def _reset_no_seed(self, return_info=False, options=None):
return self.observation_space.sample()
def _reset_no_seed(self, options=None):
return self.observation_space.sample(), {}
def _reset_seed_default(self, seed="error", return_info=False, options=None):
return self.observation_space.sample()
def _reset_seed_default(self, seed="error", options=None):
return self.observation_space.sample(), {}
def _reset_no_return_info(self, seed=None, options=None):
return self.observation_space.sample()
def _reset_no_option(self, seed=None, return_info=False):
return self.observation_space.sample()
def _reset_no_option(self, seed=None):
return self.observation_space.sample(), {}
def _make_reset_results(results):
def _reset_result(self, seed=None, return_info=False, options=None):
def _reset_result(self, seed=None, options=None):
return results
return _reset_result
@@ -280,12 +276,6 @@ def _make_reset_results(results):
"The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic. Actual default: seed='error'",
{},
],
[
UserWarning,
_reset_no_return_info,
"Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting.",
{},
],
[
UserWarning,
_reset_no_option,
@@ -293,16 +283,16 @@ def _make_reset_results(results):
{},
],
[
AssertionError,
UserWarning,
_make_reset_results([0, {}]),
"The result returned by `env.reset(return_info=True)` was not a tuple, actual type: <class 'list'>",
{"return_info": True},
"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `<class 'list'>`",
{},
],
[
AssertionError,
_make_reset_results((0, {1, 2})),
"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: <class 'set'>",
{"return_info": True},
_make_reset_results((np.array([0], dtype=np.float32), {1, 2})),
"The second element returned by `env.reset()` was not a dictionary, actual type: <class 'set'>",
{},
],
],
)
@@ -317,6 +307,8 @@ def test_passive_env_reset_checker(test, func: callable, message: str, kwargs: D
with pytest.warns(None) as warnings:
with pytest.raises(test, match=f"^{re.escape(message)}$"):
env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs)
for warning in warnings:
print(warning)
assert len(warnings) == 0

View File

@@ -28,7 +28,7 @@ def test_reset_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations = env.reset()
observations, infos = env.reset()
env.close()
@@ -40,19 +40,7 @@ def test_reset_async_vector_env(shared_memory):
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations = env.reset(return_info=False)
finally:
env.close()
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations, infos = env.reset(return_info=True)
observations, infos = env.reset()
finally:
env.close()
@@ -143,7 +131,7 @@ def test_copy_async_vector_env(shared_memory):
# TODO, these tests do nothing, understand the purpose of the tests and fix them
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=True)
observations = env.reset()
observations, infos = env.reset()
observations[0] = 0
env.close()
@@ -155,7 +143,7 @@ def test_no_copy_async_vector_env(shared_memory):
# TODO, these tests do nothing, understand the purpose of the tests and fix them
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=False)
observations = env.reset()
observations, infos = env.reset()
observations[0] = 0
env.close()
@@ -268,7 +256,7 @@ def test_custom_space_async_vector_env():
env_fns = [make_custom_space_env(i) for i in range(4)]
env = AsyncVectorEnv(env_fns, shared_memory=False)
reset_observations = env.reset()
reset_observations, reset_infos = env.reset()
assert isinstance(env.single_action_space, CustomSpace)
assert isinstance(env.action_space, Tuple)

View File

@@ -12,7 +12,7 @@ class OldStepEnv(gym.Env):
self.observation_space = Discrete(2)
def reset(self):
return 0
return 0, {}
def step(self, action):
obs = self.observation_space.sample()
@@ -28,7 +28,7 @@ class NewStepEnv(gym.Env):
self.observation_space = Discrete(2)
def reset(self):
return 0
return 0, {}
def step(self, action):
obs = self.observation_space.sample()

View File

@@ -24,7 +24,7 @@ def test_create_sync_vector_env():
def test_reset_sync_vector_env():
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
env = SyncVectorEnv(env_fns)
observations = env.reset()
observations, infos = env.reset()
env.close()
assert isinstance(env.observation_space, Box)
@@ -35,32 +35,6 @@ def test_reset_sync_vector_env():
del observations
env = SyncVectorEnv(env_fns)
observations = env.reset(return_info=False)
env.close()
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
del observations
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
env = SyncVectorEnv(env_fns)
observations, infos = env.reset(return_info=True)
env.close()
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
assert isinstance(infos, dict)
assert all([isinstance(info, dict) for info in infos])
@pytest.mark.parametrize("use_single_action_space", [True, False])
def test_step_sync_vector_env(use_single_action_space):
@@ -145,7 +119,7 @@ def test_custom_space_sync_vector_env():
env_fns = [make_custom_space_env(i) for i in range(4)]
env = SyncVectorEnv(env_fns)
reset_observations = env.reset()
reset_observations, infos = env.reset()
assert isinstance(env.single_action_space, CustomSpace)
assert isinstance(env.action_space, Tuple)

View File

@@ -22,8 +22,8 @@ def test_vector_env_equal(shared_memory):
assert async_env.action_space == sync_env.action_space
assert async_env.single_action_space == sync_env.single_action_space
async_observations = async_env.reset(seed=0)
sync_observations = sync_env.reset(seed=0)
async_observations, async_infos = async_env.reset(seed=0)
sync_observations, sync_infos = sync_env.reset(seed=0)
assert np.all(async_observations == sync_observations)
for _ in range(num_steps):

View File

@@ -63,7 +63,7 @@ class UnittestSlowEnv(gym.Env):
super().reset(seed=seed)
if self.slow_reset > 0:
time.sleep(self.slow_reset)
return self.observation_space.sample()
return self.observation_space.sample(), {}
def step(self, action):
time.sleep(action)
@@ -99,7 +99,7 @@ class CustomSpaceEnv(gym.Env):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
return "reset"
return "reset", {}
def step(self, action):
observation = f"step({action:s})"

View File

@@ -86,9 +86,7 @@ def test_atari_preprocessing_grayscale(env, obs_shape):
# It is not possible to test the outputs as we are not using actual observations.
# todo: update when ale-py is compatible with the ci
obs = env.reset(seed=0)
assert obs in env.observation_space
obs, _ = env.reset(seed=0, return_info=True)
obs, _ = env.reset(seed=0)
assert obs in env.observation_space
obs, _, _, _ = env.step(env.action_space.sample())
@@ -110,7 +108,7 @@ def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10):
noop_max=0,
)
obs = env.reset()
obs, _ = env.reset()
max_obs = 1 if scaled else 255
assert np.all(0 <= obs) and np.all(obs <= max_obs)

View File

@@ -39,19 +39,10 @@ class DummyResetEnv(gym.Env):
{"count": self.count}, # Info
)
def reset(
self,
*,
seed: Optional[int] = None,
return_info: Optional[bool] = False,
options: Optional[dict] = None
):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
"""Resets the DummyEnv to return the count array and info with count."""
self.count = 0
if not return_info:
return np.array([self.count])
else:
return np.array([self.count]), {"count": self.count}
return np.array([self.count]), {"count": self.count}
def unwrap_env(env) -> Generator[gym.Wrapper, None, None]:
@@ -113,7 +104,7 @@ def test_autoreset_wrapper_autoreset():
env = DummyResetEnv()
env = AutoResetWrapper(env)
obs, info = env.reset(return_info=True)
obs, info = env.reset()
assert obs == np.array([0])
assert info == {"count": 0}

View File

@@ -25,16 +25,10 @@ class FakeEnvironment(gym.Env):
image_shape = (32, 32, 3)
return np.zeros(image_shape, dtype=np.uint8)
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
observation = self.observation_space.sample()
return observation if not return_info else (observation, {})
return observation, {}
def step(self, action):
del action
@@ -79,8 +73,9 @@ class TestFilterObservation:
assert tuple(wrapped_env.observation_space.spaces.keys()) == tuple(filter_keys)
# Check that the added space item is consistent with the added observation.
observation = wrapped_env.reset()
observation, info = wrapped_env.reset()
assert len(observation) == len(filter_keys)
assert isinstance(info, dict)
@pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES)
def test_raises_with_incorrect_arguments(

View File

@@ -18,7 +18,7 @@ class FakeEnvironment(gym.Env):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
self.observation = self.observation_space.sample()
return self.observation
return self.observation, {}
OBSERVATION_SPACES = (
@@ -67,7 +67,7 @@ class TestFlattenEnvironment:
"""
env = FakeEnvironment(observation_space=observation_space)
wrapped_env = FlattenObservation(env)
flattened = wrapped_env.reset()
flattened, info = wrapped_env.reset()
unflattened = unflatten(env.observation_space, flattened)
original = env.observation

View File

@@ -11,11 +11,13 @@ def test_flatten_observation(env_id):
env = gym.make(env_id, disable_env_checker=True)
wrapped_env = FlattenObservation(env)
obs = env.reset()
wrapped_obs = wrapped_env.reset()
obs, info = env.reset()
wrapped_obs, wrapped_obs_info = wrapped_env.reset()
space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
wrapped_space = spaces.Box(0, 1, [32 + 11 + 2], dtype=np.int64)
assert space.contains(obs)
assert wrapped_space.contains(wrapped_obs)
assert isinstance(info, dict)
assert isinstance(wrapped_obs_info, dict)

View File

@@ -33,8 +33,8 @@ def test_frame_stack(env_id, num_stack, lz4_compress):
dup = gym.make(env_id, disable_env_checker=True)
obs = env.reset(seed=0)
dup_obs = dup.reset(seed=0)
obs, _ = env.reset(seed=0)
dup_obs, _ = dup.reset(seed=0)
assert np.allclose(obs[-1], dup_obs)
for _ in range(num_stack**2):

View File

@@ -22,5 +22,5 @@ def test_gray_scale_observation(env_id, keep_dim):
else:
assert len(wrapped_env.observation_space.shape) == 2
wrapped_obs = wrapped_env.reset()
wrapped_obs, info = wrapped_env.reset()
assert wrapped_obs in wrapped_env.observation_space

View File

@@ -23,7 +23,7 @@ class FakeEnvironment(gym.Env):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
observation = self.observation_space.sample()
return observation
return observation, {}
def step(self, action):
del action
@@ -115,5 +115,6 @@ class TestNestedDictWrapper:
def test_nested_dicts_ravel(self, observation_space, flat_shape):
env = FakeEnvironment(observation_space=observation_space)
wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
obs = wrapped_env.reset()
obs, info = wrapped_env.reset()
assert obs.shape == wrapped_env.observation_space.shape
assert isinstance(info, dict)

View File

@@ -1,7 +1,6 @@
from typing import Optional
import numpy as np
import pytest
from numpy.testing import assert_almost_equal
import gym
@@ -24,19 +23,10 @@ class DummyRewardEnv(gym.Env):
self.t += 1
return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}
def reset(
self,
*,
seed: Optional[int] = None,
return_info: Optional[bool] = False,
options: Optional[dict] = None
):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
self.t = self.return_reward_idx
if not return_info:
return np.array([self.t])
else:
return np.array([self.t]), {}
return np.array([self.t]), {}
def make_env(return_reward_idx):
@@ -47,11 +37,10 @@ def make_env(return_reward_idx):
return thunk
@pytest.mark.parametrize("return_info", [False, True])
def test_normalize_observation(return_info: bool):
def test_normalize_observation():
env = DummyRewardEnv(return_reward_idx=0)
env = NormalizeObservation(env)
env.reset(return_info=return_info)
env.reset()
env.step(env.action_space.sample())
assert_almost_equal(env.obs_rms.mean, 0.5, decimal=4)
env.step(env.action_space.sample())
@@ -61,13 +50,7 @@ def test_normalize_observation(return_info: bool):
def test_normalize_reset_info():
env = DummyRewardEnv(return_reward_idx=0)
env = NormalizeObservation(env)
obs = env.reset()
assert isinstance(obs, np.ndarray)
del obs
obs = env.reset(return_info=False)
assert isinstance(obs, np.ndarray)
del obs
obs, info = env.reset(return_info=True)
obs, info = env.reset()
assert isinstance(obs, np.ndarray)
assert isinstance(info, dict)

View File

@@ -57,8 +57,8 @@ def test_initialise_failures(env, message):
env.close()
def _reset_failure(self, seed=None, return_info=False, options=None):
return np.array([-1.0], dtype=np.float32)
def _reset_failure(self, seed=None, options=None):
return np.array([-1.0], dtype=np.float32), {}
def _step_failure(self, action):

View File

@@ -21,7 +21,7 @@ class FakeEnvironment(gym.Env):
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
observation = self.observation_space.sample()
return observation
return observation, {}
def step(self, action):
del action
@@ -82,9 +82,10 @@ def test_dict_observation(pixels_only):
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
# Check that the added space item is consistent with the added observation.
observation = wrapped_env.reset()
observation, info = wrapped_env.reset()
rgb_observation = observation[pixel_key]
assert isinstance(info, dict)
assert rgb_observation.shape == (height, width, 3)
assert rgb_observation.dtype == np.uint8
@@ -113,9 +114,10 @@ def test_single_array_observation(pixels_only):
pixel_key,
]
observation = wrapped_env.reset()
observation, info = wrapped_env.reset()
depth_observation = observation[pixel_key]
assert isinstance(info, dict)
assert depth_observation.shape == (32, 32, 3)
assert depth_observation.dtype == np.uint8

View File

@@ -31,10 +31,7 @@ def test_record_episode_statistics_reset_info():
env = gym.make("CartPole-v1", disable_env_checker=True)
env = RecordEpisodeStatistics(env)
ob_space = env.observation_space
obs = env.reset()
assert ob_space.contains(obs)
del obs
obs, info = env.reset(return_info=True)
obs, info = env.reset()
assert ob_space.contains(obs)
assert isinstance(info, dict)

View File

@@ -23,35 +23,17 @@ def test_record_video_using_default_trigger():
shutil.rmtree("videos")
def test_record_video_reset_return_info():
def test_record_video_reset():
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
ob_space = env.observation_space
obs, info = env.reset(return_info=True)
obs, info = env.reset()
env.close()
assert os.path.isdir("videos")
shutil.rmtree("videos")
assert ob_space.contains(obs)
assert isinstance(info, dict)
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
ob_space = env.observation_space
obs = env.reset(return_info=False)
env.close()
assert os.path.isdir("videos")
shutil.rmtree("videos")
assert ob_space.contains(obs)
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
ob_space = env.observation_space
obs = env.reset()
env.close()
assert os.path.isdir("videos")
shutil.rmtree("videos")
assert ob_space.contains(obs)
def test_record_video_step_trigger():
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)

View File

@@ -18,8 +18,8 @@ def test_rescale_action():
seed = 0
obs = env.reset(seed=seed)
wrapped_obs = wrapped_env.reset(seed=seed)
obs, info = env.reset(seed=seed)
wrapped_obs, wrapped_obs_info = wrapped_env.reset(seed=seed)
assert np.allclose(obs, wrapped_obs)
obs, reward, _, _ = env.step([1.5])

View File

@@ -13,7 +13,7 @@ def test_resize_observation(env_id, shape):
assert isinstance(env.observation_space, spaces.Box)
assert env.observation_space.shape[-1] == 3
obs = env.reset()
obs, _ = env.reset()
if isinstance(shape, int):
assert env.observation_space.shape[:2] == (shape, shape)
assert obs.shape == (shape, shape, 3)

View File

@@ -14,8 +14,8 @@ def test_time_aware_observation(env_id):
assert isinstance(wrapped_env.observation_space, spaces.Box)
assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1
obs = env.reset()
wrapped_obs = wrapped_env.reset()
obs, info = env.reset()
wrapped_obs, wrapped_obs_info = wrapped_env.reset()
assert wrapped_env.t == 0.0
assert wrapped_obs[-1] == 0.0
assert wrapped_obs.shape[0] == obs.shape[0] + 1
@@ -30,7 +30,7 @@ def test_time_aware_observation(env_id):
assert wrapped_obs[-1] == 2.0
assert wrapped_obs.shape[0] == obs.shape[0] + 1
wrapped_obs = wrapped_env.reset()
wrapped_obs, wrapped_obs_info = wrapped_env.reset()
assert wrapped_env.t == 0.0
assert wrapped_obs[-1] == 0.0
assert wrapped_obs.shape[0] == obs.shape[0] + 1

View File

@@ -9,13 +9,7 @@ def test_time_limit_reset_info():
env = gym.make("CartPole-v1", disable_env_checker=True)
env = TimeLimit(env)
ob_space = env.observation_space
obs = env.reset()
assert ob_space.contains(obs)
del obs
obs = env.reset(return_info=False)
assert ob_space.contains(obs)
del obs
obs, info = env.reset(return_info=True)
obs, info = env.reset()
assert ob_space.contains(obs)
assert isinstance(info, dict)

View File

@@ -15,9 +15,10 @@ def test_transform_observation(env_id):
gym.make(env_id, disable_env_checker=True), lambda obs: affine_transform(obs)
)
obs = env.reset(seed=0)
wrapped_obs = wrapped_env.reset(seed=0)
obs, info = env.reset(seed=0)
wrapped_obs, wrapped_obs_info = wrapped_env.reset(seed=0)
assert np.allclose(wrapped_obs, affine_transform(obs))
assert isinstance(wrapped_obs_info, dict)
action = env.action_space.sample()
obs, reward, done, _ = env.step(action)

View File

@@ -23,7 +23,7 @@ def test_info_to_list():
env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)
wrapped_env = VectorListInfo(env_to_wrap)
wrapped_env.action_space.seed(SEED)
_, info = wrapped_env.reset(seed=SEED, return_info=True)
_, info = wrapped_env.reset(seed=SEED)
assert isinstance(info, list)
assert len(info) == NUM_ENVS
@@ -40,7 +40,7 @@ def test_info_to_list():
def test_info_to_list_statistics():
env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)
wrapped_env = VectorListInfo(RecordEpisodeStatistics(env_to_wrap))
_, info = wrapped_env.reset(seed=SEED, return_info=True)
_, info = wrapped_env.reset(seed=SEED)
wrapped_env.action_space.seed(SEED)
assert isinstance(info, list)
assert len(info) == NUM_ENVS