mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-29 17:45:07 +00:00
Removing return_info argument to env.reset() and deprecated env.seed() function (reset now always returns info) (#2962)
* removed return_info, made info dict mandatory in reset * tentatively removed deprecated seed api for environments * added more info type checks to wrapper tests * formatting/style compliance * addressed some comments * polish to address review * fixed tests after merge, and added a test of the return_info deprecation assertion if found in reset signature * some organization of env_checker tests, reverted a probable merge error * added deprecation check for seed function in env * updated docstring * removed debug prints, tweaked test_check_seed_deprecation * changed return_info deprecation check from assertion to warning * fixes to vector envs, now should be correctly structured * added some explanation and type hints for mockup deprecated return info reset function * re-removed seed function from vector envs * added explanation to _reset_return_info_type and changed the return statement
This commit is contained in:
@@ -23,14 +23,14 @@ The Gym API's API models environments as simple Python `env` classes. Creating e
|
||||
```python
|
||||
import gym
|
||||
env = gym.make("CartPole-v1")
|
||||
observation, info = env.reset(seed=42, return_info=True)
|
||||
observation, info = env.reset(seed=42)
|
||||
|
||||
for _ in range(1000):
|
||||
action = env.action_space.sample()
|
||||
observation, reward, done, info = env.step(action)
|
||||
|
||||
if done:
|
||||
observation, info = env.reset(return_info=True)
|
||||
observation, info = env.reset()
|
||||
env.close()
|
||||
```
|
||||
|
||||
|
53
gym/core.py
53
gym/core.py
@@ -41,11 +41,10 @@ class Env(Generic[ObsType, ActType]):
|
||||
The main API methods that users of this class need to know are:
|
||||
|
||||
- :meth:`step` - Takes a step in the environment using an action returning the next observation, reward,
|
||||
if the environment terminated and more information.
|
||||
- :meth:`reset` - Resets the environment to an initial state, returning the initial observation.
|
||||
if the environment terminated and observation information.
|
||||
- :meth:`reset` - Resets the environment to an initial state, returning the initial observation and observation information.
|
||||
- :meth:`render` - Renders the environment observation with modes depending on the output
|
||||
- :meth:`close` - Closes the environment, important for rendering where pygame is imported
|
||||
- :meth:`seed` - Seeds the environment's random number generator, :deprecated: in favor of `Env.reset(seed=seed)`.
|
||||
|
||||
And set the following attributes:
|
||||
|
||||
@@ -124,9 +123,8 @@ class Env(Generic[ObsType, ActType]):
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||
) -> Tuple[ObsType, dict]:
|
||||
"""Resets the environment to an initial state and returns the initial observation.
|
||||
|
||||
This method can reset the environment's random number generator(s) if ``seed`` is an integer or
|
||||
@@ -143,8 +141,6 @@ class Env(Generic[ObsType, ActType]):
|
||||
If you pass an integer, the PRNG will be reset even if it already exists.
|
||||
Usually, you want to pass an integer *right after the environment has been initialized and then never again*.
|
||||
Please refer to the minimal example above to see this paradigm in action.
|
||||
return_info (bool): If true, return additional information along with initial observation.
|
||||
This info should be analogous to the info returned in :meth:`step`
|
||||
options (optional dict): Additional information to specify how the environment is reset (optional,
|
||||
depending on the specific environment)
|
||||
|
||||
@@ -152,8 +148,7 @@ class Env(Generic[ObsType, ActType]):
|
||||
Returns:
|
||||
observation (object): Observation of the initial state. This will be an element of :attr:`observation_space`
|
||||
(typically a numpy array) and is analogous to the observation returned by :meth:`step`.
|
||||
info (optional dictionary): This will *only* be returned if ``return_info=True`` is passed.
|
||||
It contains auxiliary information complementing ``observation``. This dictionary should be analogous to
|
||||
info (dictionary): This dictionary contains auxiliary information complementing ``observation``. It should be analogous to
|
||||
the ``info`` returned by :meth:`step`.
|
||||
"""
|
||||
# Initialize the RNG if the seed is manually passed
|
||||
@@ -193,33 +188,6 @@ class Env(Generic[ObsType, ActType]):
|
||||
"""
|
||||
pass
|
||||
|
||||
def seed(self, seed=None):
|
||||
""":deprecated: function that sets the seed for the environment's random number generator(s).
|
||||
|
||||
Use `env.reset(seed=seed)` as the new API for setting the seed of the environment.
|
||||
|
||||
Note:
|
||||
Some environments use multiple pseudorandom number generators.
|
||||
We want to capture all such seeds used in order to ensure that
|
||||
there aren't accidental correlations between multiple generators.
|
||||
|
||||
Args:
|
||||
seed(Optional int): The seed value for the random number generator
|
||||
|
||||
Returns:
|
||||
seeds (List[int]): Returns the list of seeds used in this environment's random
|
||||
number generators. The first value in the list should be the
|
||||
"main" seed, or the value which a reproducer should pass to
|
||||
'seed'. Often, the main seed equals the provided 'seed', but
|
||||
this won't be true `if seed=None`, for example.
|
||||
"""
|
||||
deprecation(
|
||||
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
|
||||
"Please use `env.reset(seed=seed)` instead."
|
||||
)
|
||||
self._np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
@property
|
||||
def unwrapped(self) -> "Env":
|
||||
"""Returns the base non-wrapped environment.
|
||||
@@ -370,7 +338,7 @@ class Wrapper(Env[ObsType, ActType]):
|
||||
|
||||
return step_api_compatibility(self.env.step(action), self.new_step_api)
|
||||
|
||||
def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||
def reset(self, **kwargs) -> Tuple[ObsType, dict]:
|
||||
"""Resets the environment with kwargs."""
|
||||
return self.env.reset(**kwargs)
|
||||
|
||||
@@ -384,10 +352,6 @@ class Wrapper(Env[ObsType, ActType]):
|
||||
"""Closes the environment."""
|
||||
return self.env.close()
|
||||
|
||||
def seed(self, seed=None):
|
||||
"""Seeds the environment."""
|
||||
return self.env.seed(seed)
|
||||
|
||||
def __str__(self):
|
||||
"""Returns the wrapper name and the unwrapped environment string."""
|
||||
return f"<{type(self).__name__}{self.env}>"
|
||||
@@ -432,11 +396,8 @@ class ObservationWrapper(Wrapper):
|
||||
|
||||
def reset(self, **kwargs):
|
||||
"""Resets the environment, returning a modified observation using :meth:`self.observation`."""
|
||||
if kwargs.get("return_info", False):
|
||||
obs, info = self.env.reset(**kwargs)
|
||||
return self.observation(obs), info
|
||||
else:
|
||||
return self.observation(self.env.reset(**kwargs))
|
||||
obs, info = self.env.reset(**kwargs)
|
||||
return self.observation(obs), info
|
||||
|
||||
def step(self, action):
|
||||
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
|
||||
|
@@ -428,7 +428,6 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
@@ -514,10 +513,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
|
||||
self.lidar = [LidarCallback() for _ in range(10)]
|
||||
self.renderer.reset()
|
||||
if not return_info:
|
||||
return self.step(np.array([0, 0, 0, 0]))[0]
|
||||
else:
|
||||
return self.step(np.array([0, 0, 0, 0]))[0], {}
|
||||
return self.step(np.array([0, 0, 0, 0]))[0], {}
|
||||
|
||||
def step(self, action: np.ndarray):
|
||||
assert self.hull is not None
|
||||
|
@@ -483,7 +483,6 @@ class CarRacing(gym.Env, EzPickle):
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
@@ -519,10 +518,7 @@ class CarRacing(gym.Env, EzPickle):
|
||||
self.car = Car(self.world, *self.track[0][1:4])
|
||||
|
||||
self.renderer.reset()
|
||||
if not return_info:
|
||||
return self.step(None)[0]
|
||||
else:
|
||||
return self.step(None)[0], {}
|
||||
return self.step(None)[0], {}
|
||||
|
||||
def step(self, action: Union[np.ndarray, int]):
|
||||
assert self.car is not None
|
||||
|
@@ -305,7 +305,6 @@ class LunarLander(gym.Env, EzPickle):
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
@@ -413,10 +412,7 @@ class LunarLander(gym.Env, EzPickle):
|
||||
self.drawlist = [self.lander] + self.legs
|
||||
|
||||
self.renderer.reset()
|
||||
if not return_info:
|
||||
return self.step(np.array([0, 0]) if self.continuous else 0)[0]
|
||||
else:
|
||||
return self.step(np.array([0, 0]) if self.continuous else 0)[0], {}
|
||||
return self.step(np.array([0, 0]) if self.continuous else 0)[0], {}
|
||||
|
||||
def _create_particle(self, mass, x, y, ttl):
|
||||
p = self.world.CreateDynamicBody(
|
||||
@@ -774,7 +770,7 @@ def demo_heuristic_lander(env, seed=None, render=False):
|
||||
|
||||
total_reward = 0
|
||||
steps = 0
|
||||
s = env.reset(seed=seed)
|
||||
s, info = env.reset(seed=seed)
|
||||
while True:
|
||||
a = heuristic(env, s)
|
||||
s, r, terminated, truncated, info = step_api_compatibility(env.step(a), True)
|
||||
|
@@ -180,13 +180,7 @@ class AcrobotEnv(core.Env):
|
||||
self.action_space = spaces.Discrete(3)
|
||||
self.state = None
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
super().reset(seed=seed)
|
||||
# Note that if you use custom reset bounds, it may lead to out-of-bound
|
||||
# state/observations.
|
||||
@@ -199,10 +193,7 @@ class AcrobotEnv(core.Env):
|
||||
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
if not return_info:
|
||||
return self._get_ob()
|
||||
else:
|
||||
return self._get_ob(), {}
|
||||
return self._get_ob(), {}
|
||||
|
||||
def step(self, a):
|
||||
s = self.state
|
||||
|
@@ -192,7 +192,6 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
@@ -205,10 +204,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
||||
self.steps_beyond_terminated = None
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
if not return_info:
|
||||
return np.array(self.state, dtype=np.float32)
|
||||
else:
|
||||
return np.array(self.state, dtype=np.float32), {}
|
||||
return np.array(self.state, dtype=np.float32), {}
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
@@ -174,13 +174,7 @@ class Continuous_MountainCarEnv(gym.Env):
|
||||
self.renderer.render_step()
|
||||
return self.state, reward, terminated, False, {}
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
super().reset(seed=seed)
|
||||
# Note that if you use custom reset bounds, it may lead to out-of-bound
|
||||
# state/observations.
|
||||
@@ -188,10 +182,7 @@ class Continuous_MountainCarEnv(gym.Env):
|
||||
self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
if not return_info:
|
||||
return np.array(self.state, dtype=np.float32)
|
||||
else:
|
||||
return np.array(self.state, dtype=np.float32), {}
|
||||
return np.array(self.state, dtype=np.float32), {}
|
||||
|
||||
def _height(self, xs):
|
||||
return np.sin(3 * xs) * 0.45 + 0.55
|
||||
|
@@ -152,7 +152,6 @@ class MountainCarEnv(gym.Env):
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
@@ -162,10 +161,7 @@ class MountainCarEnv(gym.Env):
|
||||
self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
if not return_info:
|
||||
return np.array(self.state, dtype=np.float32)
|
||||
else:
|
||||
return np.array(self.state, dtype=np.float32), {}
|
||||
return np.array(self.state, dtype=np.float32), {}
|
||||
|
||||
def _height(self, xs):
|
||||
return np.sin(3 * xs) * 0.45 + 0.55
|
||||
|
@@ -138,13 +138,7 @@ class PendulumEnv(gym.Env):
|
||||
self.renderer.render_step()
|
||||
return self._get_obs(), -costs, False, False, {}
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
super().reset(seed=seed)
|
||||
if options is None:
|
||||
high = np.array([DEFAULT_X, DEFAULT_Y])
|
||||
@@ -162,10 +156,7 @@ class PendulumEnv(gym.Env):
|
||||
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
if not return_info:
|
||||
return self._get_obs()
|
||||
else:
|
||||
return self._get_obs(), {}
|
||||
return self._get_obs(), {}
|
||||
|
||||
def _get_obs(self):
|
||||
theta, thetadot = self.state
|
||||
|
@@ -142,7 +142,6 @@ class BaseMujocoEnv(gym.Env):
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
@@ -152,10 +151,7 @@ class BaseMujocoEnv(gym.Env):
|
||||
ob = self.reset_model()
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
if not return_info:
|
||||
return ob
|
||||
else:
|
||||
return ob, {}
|
||||
return ob, {}
|
||||
|
||||
def set_state(self, qpos, qvel):
|
||||
"""
|
||||
|
@@ -167,7 +167,6 @@ class BlackjackEnv(gym.Env):
|
||||
def reset(
|
||||
self,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
@@ -189,10 +188,7 @@ class BlackjackEnv(gym.Env):
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
|
||||
if not return_info:
|
||||
return self._get_obs()
|
||||
else:
|
||||
return self._get_obs(), {}
|
||||
return self._get_obs(), {}
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
@@ -149,22 +149,14 @@ class CliffWalkingEnv(Env):
|
||||
self.renderer.render_step()
|
||||
return (int(s), r, t, False, {"prob": p})
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
super().reset(seed=seed)
|
||||
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
|
||||
self.lastaction = None
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
if not return_info:
|
||||
return int(self.s)
|
||||
else:
|
||||
return int(self.s), {"prob": 1}
|
||||
|
||||
return int(self.s), {"prob": 1}
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
@@ -256,7 +256,6 @@ class FrozenLakeEnv(Env):
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
@@ -266,10 +265,7 @@ class FrozenLakeEnv(Env):
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
|
||||
if not return_info:
|
||||
return int(self.s)
|
||||
else:
|
||||
return int(self.s), {"prob": 1}
|
||||
return int(self.s), {"prob": 1}
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
@@ -89,7 +89,7 @@ class TaxiEnv(Env):
|
||||
|
||||
### Info
|
||||
|
||||
``step`` and ``reset(return_info=True)`` will return an info dictionary that contains "p" and "action_mask" containing
|
||||
``step`` and ``reset()`` will return an info dictionary that contains "p" and "action_mask" containing
|
||||
the probability that the state is taken and a mask of what actions will result in a change of state to speed up training.
|
||||
|
||||
As Taxi's initial state is a stochastic, the "p" key represents the probability of the
|
||||
@@ -266,7 +266,6 @@ class TaxiEnv(Env):
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
@@ -275,10 +274,8 @@ class TaxiEnv(Env):
|
||||
self.taxi_orientation = 0
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
if not return_info:
|
||||
return int(self.s)
|
||||
else:
|
||||
return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)}
|
||||
|
||||
return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)}
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
@@ -73,7 +73,7 @@ def check_reset_seed(env: gym.Env):
|
||||
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
|
||||
):
|
||||
try:
|
||||
obs_1 = env.reset(seed=123)
|
||||
obs_1, info = env.reset(seed=123)
|
||||
assert (
|
||||
obs_1 in env.observation_space
|
||||
), "The observation returned by `env.reset(seed=123)` is not within the observation space."
|
||||
@@ -85,7 +85,7 @@ def check_reset_seed(env: gym.Env):
|
||||
env.unwrapped._np_random # pyright: ignore [reportPrivateUsage]
|
||||
)
|
||||
|
||||
obs_2 = env.reset(seed=123)
|
||||
obs_2, info = env.reset(seed=123)
|
||||
assert (
|
||||
obs_2 in env.observation_space
|
||||
), "The observation returned by `env.reset(seed=123)` is not within the observation space."
|
||||
@@ -98,7 +98,7 @@ def check_reset_seed(env: gym.Env):
|
||||
== seed_123_rng.bit_generator.state
|
||||
), "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`."
|
||||
|
||||
obs_3 = env.reset(seed=456)
|
||||
obs_3, info = env.reset(seed=456)
|
||||
assert (
|
||||
obs_3 in env.observation_space
|
||||
), "The observation returned by `env.reset(seed=456)` is not within the observation space."
|
||||
@@ -126,53 +126,6 @@ def check_reset_seed(env: gym.Env):
|
||||
)
|
||||
|
||||
|
||||
def check_reset_info(env: gym.Env):
|
||||
"""Checks that :meth:`reset` supports the ``return_info`` keyword.
|
||||
|
||||
Args:
|
||||
env: The environment to check
|
||||
|
||||
Raises:
|
||||
AssertionError: The environment cannot be reset with `return_info=True`,
|
||||
even though `return_info` or `kwargs` appear in the signature.
|
||||
"""
|
||||
signature = inspect.signature(env.reset)
|
||||
if "return_info" in signature.parameters or (
|
||||
"kwargs" in signature.parameters
|
||||
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
|
||||
):
|
||||
try:
|
||||
obs = env.reset(return_info=False)
|
||||
assert (
|
||||
obs in env.observation_space
|
||||
), "The value returned by `env.reset(return_info=True)` is not within the observation space."
|
||||
|
||||
result = env.reset(return_info=True)
|
||||
assert isinstance(
|
||||
result, tuple
|
||||
), f"Calling the reset method with `return_info=True` did not return a tuple, actual type: {type(result)}"
|
||||
assert (
|
||||
len(result) == 2
|
||||
), f"Calling the reset method with `return_info=True` did not return a 2-tuple, actual length: {len(result)}"
|
||||
|
||||
obs, info = result
|
||||
assert (
|
||||
obs in env.observation_space
|
||||
), "The first element returned by `env.reset(return_info=True)` is not within the observation space."
|
||||
assert isinstance(
|
||||
info, dict
|
||||
), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: {type(info)}"
|
||||
except TypeError as e:
|
||||
raise AssertionError(
|
||||
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` appear in the signature. "
|
||||
f"This should never happen, please report this issue. The error was: {e}"
|
||||
)
|
||||
else:
|
||||
raise gym.error.Error(
|
||||
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument."
|
||||
)
|
||||
|
||||
|
||||
def check_reset_options(env: gym.Env):
|
||||
"""Check that the environment can be reset with options.
|
||||
|
||||
@@ -201,6 +154,64 @@ def check_reset_options(env: gym.Env):
|
||||
)
|
||||
|
||||
|
||||
def check_reset_return_info_deprecation(env: gym.Env):
    """Warn if ``env.reset`` still declares the deprecated ``return_info`` parameter.

    The signature of ``env.reset`` is inspected; when a parameter named
    ``return_info`` is present, a deprecation message is reported through
    ``logger.warn``. The environment itself is never reset by this check.

    Args:
        env: The environment to check
    """
    signature = inspect.signature(env.reset)
    if "return_info" in signature.parameters:
        # Adjacent string literals are joined with no separator, so each
        # fragment must carry its own trailing space to keep words apart.
        logger.warn(
            "`return_info` is deprecated as an optional argument to `reset`. `reset` "
            "should now always return `obs, info` where `obs` is an observation, and `info` is a dictionary "
            "containing additional information."
        )
|
||||
|
||||
|
||||
def check_seed_deprecation(env: gym.Env):
    """Warn if the environment still exposes a callable ``seed`` attribute.

    Looks up ``seed`` on ``env`` (a missing attribute is treated as ``None``)
    and, when the value is callable, reports through ``logger.warn`` that the
    ``seed`` function is no longer officially supported.

    Args:
        env: The environment to check
    """
    # Nothing to report when there is no `seed` attribute or it is not callable.
    if not callable(getattr(env, "seed", None)):
        return

    logger.warn(
        "Official support for the `seed` function is dropped. "
        "Standard practice is to reset gym environments using `env.reset(seed=<desired seed>)`"
    )
|
||||
|
||||
|
||||
def check_reset_return_type(env: gym.Env):
    """Verify that ``env.reset()`` returns an ``(obs, info)`` pair.

    Calls ``env.reset()`` once with no arguments and asserts, in order, that
    the result is a tuple, that it has exactly two elements, that the first
    element is contained in ``env.observation_space``, and that the second
    element is a ``dict``.

    Args:
        env: The environment to check

    Raises:
        AssertionError: If any of the checks above fails.
    """
    outcome = env.reset()
    assert isinstance(
        outcome, tuple
    ), f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(outcome)}`"

    n_items = len(outcome)
    assert (
        n_items == 2
    ), f"Calling the reset method did not return a 2-tuple, actual length: {n_items}"

    # Safe to unpack now that the shape has been validated.
    initial_obs, reset_info = outcome
    assert (
        initial_obs in env.observation_space
    ), "The first element returned by `env.reset()` is not within the observation space."
    assert isinstance(
        reset_info, dict
    ), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(reset_info)}"
|
||||
|
||||
|
||||
def check_space_limit(space, space_type: str):
|
||||
"""Check the space limit for only the Box space as a test that only runs as part of `check_env`."""
|
||||
if isinstance(space, spaces.Box):
|
||||
@@ -279,9 +290,11 @@ def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = False):
|
||||
check_space_limit(env.observation_space, "observation")
|
||||
|
||||
# ==== Check the reset method ====
|
||||
check_seed_deprecation(env)
|
||||
check_reset_return_info_deprecation(env)
|
||||
check_reset_return_type(env)
|
||||
check_reset_seed(env)
|
||||
check_reset_options(env)
|
||||
check_reset_info(env)
|
||||
|
||||
# ============ Check the returned values ===============
|
||||
env_reset_passive_checker(env)
|
||||
|
@@ -183,14 +183,6 @@ def env_reset_passive_checker(env, **kwargs):
|
||||
f"Actual default: {seed_param}"
|
||||
)
|
||||
|
||||
if "return_info" not in signature.parameters and not (
|
||||
"kwargs" in signature.parameters
|
||||
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
|
||||
):
|
||||
logger.warn(
|
||||
"Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting."
|
||||
)
|
||||
|
||||
if "options" not in signature.parameters and "kwargs" not in signature.parameters:
|
||||
logger.warn(
|
||||
"Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information."
|
||||
@@ -198,21 +190,17 @@ def env_reset_passive_checker(env, **kwargs):
|
||||
|
||||
# Checks the result of env.reset with kwargs
|
||||
result = env.reset(**kwargs)
|
||||
if kwargs.get("return_info", False) is True:
|
||||
assert isinstance(
|
||||
result, tuple
|
||||
), f"The result returned by `env.reset(return_info=True)` was not a tuple, actual type: {type(result)}"
|
||||
assert (
|
||||
len(result) == 2
|
||||
), f"The length of the result returned by `env.reset(return_info=True)` is not 2, actual length: {len(result)}"
|
||||
obs, info = result
|
||||
assert isinstance(
|
||||
info, dict
|
||||
), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: {type(info)}"
|
||||
else:
|
||||
obs = result
|
||||
|
||||
if not isinstance(result, tuple):
|
||||
logger.warn(
|
||||
f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(result)}`"
|
||||
)
|
||||
|
||||
obs, info = result
|
||||
check_obs(obs, env.observation_space, "reset")
|
||||
assert isinstance(
|
||||
info, dict
|
||||
), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}"
|
||||
return result
|
||||
|
||||
|
||||
|
@@ -171,38 +171,9 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self._state = AsyncState.DEFAULT
|
||||
self._check_spaces()
|
||||
|
||||
def seed(self, seed=None):
|
||||
"""Seeds the vector environments.
|
||||
|
||||
Args:
|
||||
seed: The seeds use with the environments
|
||||
|
||||
Raises:
|
||||
AlreadyPendingCallError: Calling `seed` while waiting for a pending call to complete
|
||||
"""
|
||||
super().seed(seed=seed)
|
||||
self._assert_is_running()
|
||||
if seed is None:
|
||||
seed = [None for _ in range(self.num_envs)]
|
||||
if isinstance(seed, int):
|
||||
seed = [seed + i for i in range(self.num_envs)]
|
||||
assert len(seed) == self.num_envs
|
||||
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError(
|
||||
f"Calling `seed` while waiting for a pending call to `{self._state.value}` to complete.",
|
||||
self._state.value,
|
||||
)
|
||||
|
||||
for pipe, seed in zip(self.parent_pipes, seed):
|
||||
pipe.send(("seed", seed))
|
||||
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
||||
self._raise_if_errors(successes)
|
||||
|
||||
def reset_async(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
"""Send calls to the :obj:`reset` methods of the sub-environments.
|
||||
@@ -211,7 +182,6 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
Args:
|
||||
seed: List of seeds for each environment
|
||||
return_info: If to return information
|
||||
options: The reset option
|
||||
|
||||
Raises:
|
||||
@@ -238,8 +208,6 @@ class AsyncVectorEnv(VectorEnv):
|
||||
single_kwargs = {}
|
||||
if single_seed is not None:
|
||||
single_kwargs["seed"] = single_seed
|
||||
if return_info:
|
||||
single_kwargs["return_info"] = return_info
|
||||
if options is not None:
|
||||
single_kwargs["options"] = options
|
||||
|
||||
@@ -250,7 +218,6 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self,
|
||||
timeout: Optional[Union[int, float]] = None,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
) -> Union[ObsType, Tuple[ObsType, List[dict]]]:
|
||||
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
|
||||
@@ -258,7 +225,6 @@ class AsyncVectorEnv(VectorEnv):
|
||||
Args:
|
||||
timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out.
|
||||
seed: ignored
|
||||
return_info: If to return information
|
||||
options: ignored
|
||||
|
||||
Returns:
|
||||
@@ -286,27 +252,17 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self._raise_if_errors(successes)
|
||||
self._state = AsyncState.DEFAULT
|
||||
|
||||
if return_info:
|
||||
infos = {}
|
||||
results, info_data = zip(*results)
|
||||
for i, info in enumerate(info_data):
|
||||
infos = self._add_info(infos, info, i)
|
||||
infos = {}
|
||||
results, info_data = zip(*results)
|
||||
for i, info in enumerate(info_data):
|
||||
infos = self._add_info(infos, info, i)
|
||||
|
||||
if not self.shared_memory:
|
||||
self.observations = concatenate(
|
||||
self.single_observation_space, results, self.observations
|
||||
)
|
||||
if not self.shared_memory:
|
||||
self.observations = concatenate(
|
||||
self.single_observation_space, results, self.observations
|
||||
)
|
||||
|
||||
return (
|
||||
deepcopy(self.observations) if self.copy else self.observations
|
||||
), infos
|
||||
else:
|
||||
if not self.shared_memory:
|
||||
self.observations = concatenate(
|
||||
self.single_observation_space, results, self.observations
|
||||
)
|
||||
|
||||
return deepcopy(self.observations) if self.copy else self.observations
|
||||
return (deepcopy(self.observations) if self.copy else self.observations), infos
|
||||
|
||||
def step_async(self, actions: np.ndarray):
|
||||
"""Send the calls to :obj:`step` to each sub-environment.
|
||||
@@ -606,12 +562,8 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
while True:
|
||||
command, data = pipe.recv()
|
||||
if command == "reset":
|
||||
if "return_info" in data and data["return_info"] is True:
|
||||
observation, info = env.reset(**data)
|
||||
pipe.send(((observation, info), True))
|
||||
else:
|
||||
observation = env.reset(**data)
|
||||
pipe.send((observation, True))
|
||||
observation, info = env.reset(**data)
|
||||
pipe.send(((observation, info), True))
|
||||
|
||||
elif command == "step":
|
||||
(
|
||||
@@ -622,8 +574,9 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
info,
|
||||
) = step_api_compatibility(env.step(data), True)
|
||||
if terminated or truncated:
|
||||
info["final_observation"] = observation
|
||||
observation = env.reset()
|
||||
old_observation = observation
|
||||
observation, info = env.reset()
|
||||
info["final_observation"] = old_observation
|
||||
pipe.send(((observation, reward, terminated, truncated, info), True))
|
||||
elif command == "seed":
|
||||
env.seed(data)
|
||||
@@ -676,18 +629,12 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
|
||||
while True:
|
||||
command, data = pipe.recv()
|
||||
if command == "reset":
|
||||
if "return_info" in data and data["return_info"] is True:
|
||||
observation, info = env.reset(**data)
|
||||
write_to_shared_memory(
|
||||
observation_space, index, observation, shared_memory
|
||||
)
|
||||
pipe.send(((None, info), True))
|
||||
else:
|
||||
observation = env.reset(**data)
|
||||
write_to_shared_memory(
|
||||
observation_space, index, observation, shared_memory
|
||||
)
|
||||
pipe.send((None, True))
|
||||
observation, info = env.reset(**data)
|
||||
write_to_shared_memory(
|
||||
observation_space, index, observation, shared_memory
|
||||
)
|
||||
pipe.send(((None, info), True))
|
||||
|
||||
elif command == "step":
|
||||
(
|
||||
observation,
|
||||
@@ -697,8 +644,10 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
|
||||
info,
|
||||
) = step_api_compatibility(env.step(data), True)
|
||||
if terminated or truncated:
|
||||
info["final_observation"] = observation
|
||||
observation = env.reset()
|
||||
old_observation = observation
|
||||
observation, info = env.reset()
|
||||
info["final_observation"] = old_observation
|
||||
|
||||
write_to_shared_memory(
|
||||
observation_space, index, observation, shared_memory
|
||||
)
|
||||
|
@@ -93,14 +93,12 @@ class SyncVectorEnv(VectorEnv):
|
||||
def reset_wait(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
|
||||
|
||||
Args:
|
||||
seed: The reset environment seed
|
||||
return_info: If to return information
|
||||
options: Option information for the environment reset
|
||||
|
||||
Returns:
|
||||
@@ -123,26 +121,15 @@ class SyncVectorEnv(VectorEnv):
|
||||
kwargs["seed"] = single_seed
|
||||
if options is not None:
|
||||
kwargs["options"] = options
|
||||
if return_info is True:
|
||||
kwargs["return_info"] = return_info
|
||||
|
||||
if not return_info:
|
||||
observation = env.reset(**kwargs)
|
||||
observations.append(observation)
|
||||
else:
|
||||
observation, info = env.reset(**kwargs)
|
||||
observations.append(observation)
|
||||
infos = self._add_info(infos, info, i)
|
||||
observation, info = env.reset(**kwargs)
|
||||
observations.append(observation)
|
||||
infos = self._add_info(infos, info, i)
|
||||
|
||||
self.observations = concatenate(
|
||||
self.single_observation_space, observations, self.observations
|
||||
)
|
||||
if not return_info:
|
||||
return deepcopy(self.observations) if self.copy else self.observations
|
||||
else:
|
||||
return (
|
||||
deepcopy(self.observations) if self.copy else self.observations
|
||||
), infos
|
||||
return (deepcopy(self.observations) if self.copy else self.observations), infos
|
||||
|
||||
def step_async(self, actions):
|
||||
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
|
||||
@@ -164,8 +151,9 @@ class SyncVectorEnv(VectorEnv):
|
||||
info,
|
||||
) = step_api_compatibility(env.step(action), True)
|
||||
if self._terminateds[i] or self._truncateds[i]:
|
||||
info["final_observation"] = observation
|
||||
observation = env.reset()
|
||||
old_observation = observation
|
||||
observation, info = env.reset()
|
||||
info["final_observation"] = old_observation
|
||||
observations.append(observation)
|
||||
infos = self._add_info(infos, info, i)
|
||||
self.observations = concatenate(
|
||||
|
@@ -60,7 +60,6 @@ class VectorEnv(gym.Env):
|
||||
def reset_async(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
"""Reset the sub-environments asynchronously.
|
||||
@@ -70,7 +69,6 @@ class VectorEnv(gym.Env):
|
||||
|
||||
Args:
|
||||
seed: The reset seed
|
||||
return_info: If to return info
|
||||
options: Reset options
|
||||
"""
|
||||
pass
|
||||
@@ -78,7 +76,6 @@ class VectorEnv(gym.Env):
|
||||
def reset_wait(
|
||||
self,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
"""Retrieves the results of a :meth:`reset_async` call.
|
||||
@@ -87,7 +84,6 @@ class VectorEnv(gym.Env):
|
||||
|
||||
Args:
|
||||
seed: The reset seed
|
||||
return_info: If to return info
|
||||
options: Reset options
|
||||
|
||||
Returns:
|
||||
@@ -102,21 +98,19 @@ class VectorEnv(gym.Env):
|
||||
self,
|
||||
*,
|
||||
seed: Optional[Union[int, List[int]]] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
):
|
||||
"""Reset all parallel environments and return a batch of initial observations.
|
||||
|
||||
Args:
|
||||
seed: The environment reset seeds
|
||||
return_info: If to return the info
|
||||
options: The reset options
|
||||
|
||||
Returns:
|
||||
A batch of observations and the corresponding info dictionary from the vectorized environment.
|
||||
"""
|
||||
self.reset_async(seed=seed, return_info=return_info, options=options)
|
||||
return self.reset_wait(seed=seed, return_info=return_info, options=options)
|
||||
self.reset_async(seed=seed, options=options)
|
||||
return self.reset_wait(seed=seed, options=options)
|
||||
|
||||
def step_async(self, actions):
|
||||
"""Asynchronously performs steps in the sub-environments.
|
||||
@@ -220,21 +214,6 @@ class VectorEnv(gym.Env):
|
||||
self.close_extras(**kwargs)
|
||||
self.closed = True
|
||||
|
||||
def seed(self, seed=None):
|
||||
"""Set the random seed in all parallel environments.
|
||||
|
||||
Args:
|
||||
seed: Random seed for each parallel environment. If ``seed`` is a list of
|
||||
length ``num_envs``, then the items of the list are chosen as random
|
||||
seeds. If ``seed`` is an int, then each parallel environment uses the random
|
||||
seed ``seed + n``, where ``n`` is the index of the parallel environment
|
||||
(between ``0`` and ``num_envs - 1``).
|
||||
"""
|
||||
deprecation(
|
||||
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
|
||||
"Please use `env.reset(seed=seed) instead in VectorEnvs."
|
||||
)
|
||||
|
||||
def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
|
||||
"""Add env info to the info dictionary of the vectorized environment.
|
||||
|
||||
@@ -339,9 +318,6 @@ class VectorEnvWrapper(VectorEnv):
|
||||
def close_extras(self, **kwargs):
|
||||
return self.env.close_extras(**kwargs)
|
||||
|
||||
def seed(self, seed=None):
|
||||
return self.env.seed(seed)
|
||||
|
||||
def call(self, name, *args, **kwargs):
|
||||
return self.env.call(name, *args, **kwargs)
|
||||
|
||||
|
@@ -151,11 +151,7 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
def reset(self, **kwargs):
|
||||
"""Resets the environment using preprocessing."""
|
||||
# NoopReset
|
||||
if kwargs.get("return_info", False):
|
||||
_, reset_info = self.env.reset(**kwargs)
|
||||
else:
|
||||
_ = self.env.reset(**kwargs)
|
||||
reset_info = {}
|
||||
_, reset_info = self.env.reset(**kwargs)
|
||||
|
||||
noops = (
|
||||
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
|
||||
@@ -168,11 +164,7 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
)
|
||||
reset_info.update(step_info)
|
||||
if terminated or truncated:
|
||||
if kwargs.get("return_info", False):
|
||||
_, reset_info = self.env.reset(**kwargs)
|
||||
else:
|
||||
_ = self.env.reset(**kwargs)
|
||||
reset_info = {}
|
||||
_, reset_info = self.env.reset(**kwargs)
|
||||
|
||||
self.lives = self.ale.lives()
|
||||
if self.grayscale_obs:
|
||||
@@ -181,10 +173,7 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
self.ale.getScreenRGB(self.obs_buffer[0])
|
||||
self.obs_buffer[1].fill(0)
|
||||
|
||||
if kwargs.get("return_info", False):
|
||||
return self._get_obs(), reset_info
|
||||
else:
|
||||
return self._get_obs()
|
||||
return self._get_obs(), reset_info
|
||||
|
||||
def _get_obs(self):
|
||||
if self.frame_skip > 1: # more efficient in-place pooling
|
||||
|
@@ -48,7 +48,7 @@ class AutoResetWrapper(gym.Wrapper):
|
||||
|
||||
if terminated or truncated:
|
||||
|
||||
new_obs, new_info = self.env.reset(return_info=True)
|
||||
new_obs, new_info = self.env.reset()
|
||||
assert (
|
||||
"final_observation" not in new_info
|
||||
), 'info dict cannot contain key "final_observation" '
|
||||
|
@@ -191,14 +191,8 @@ class FrameStack(gym.ObservationWrapper):
|
||||
Returns:
|
||||
The stacked observations
|
||||
"""
|
||||
if kwargs.get("return_info", False):
|
||||
obs, info = self.env.reset(**kwargs)
|
||||
else:
|
||||
obs = self.env.reset(**kwargs)
|
||||
info = None # Unused
|
||||
obs, info = self.env.reset(**kwargs)
|
||||
|
||||
[self.frames.append(obs) for _ in range(self.num_stack)]
|
||||
|
||||
if kwargs.get("return_info", False):
|
||||
return self.observation(None), info
|
||||
else:
|
||||
return self.observation(None)
|
||||
return self.observation(None), info
|
||||
|
@@ -89,20 +89,12 @@ class NormalizeObservation(gym.core.Wrapper):
|
||||
|
||||
def reset(self, **kwargs):
|
||||
"""Resets the environment and normalizes the observation."""
|
||||
if kwargs.get("return_info", False):
|
||||
obs, info = self.env.reset(**kwargs)
|
||||
obs, info = self.env.reset(**kwargs)
|
||||
|
||||
if self.is_vector_env:
|
||||
return self.normalize(obs), info
|
||||
else:
|
||||
return self.normalize(np.array([obs]))[0], info
|
||||
if self.is_vector_env:
|
||||
return self.normalize(obs), info
|
||||
else:
|
||||
obs = self.env.reset(**kwargs)
|
||||
|
||||
if self.is_vector_env:
|
||||
return self.normalize(obs)
|
||||
else:
|
||||
return self.normalize(np.array([obs]))[0]
|
||||
return self.normalize(np.array([obs]))[0], info
|
||||
|
||||
def normalize(self, obs):
|
||||
"""Normalises the observation using the running mean and variance of the observations."""
|
||||
|
@@ -57,9 +57,6 @@ class VectorListInfo(gym.Wrapper):
|
||||
|
||||
def reset(self, **kwargs):
|
||||
"""Resets the environment using kwargs."""
|
||||
if not kwargs.get("return_info"):
|
||||
return self.env.reset(**kwargs)
|
||||
|
||||
obs, infos = self.env.reset(**kwargs)
|
||||
list_info = self._convert_info_to_list(infos)
|
||||
return obs, list_info
|
||||
|
@@ -139,7 +139,7 @@ def test_taxi_action_mask():
|
||||
def test_taxi_encode_decode():
|
||||
env = TaxiEnv()
|
||||
|
||||
state = env.reset()
|
||||
state, info = env.reset()
|
||||
for _ in range(100):
|
||||
assert (
|
||||
env.encode(*env.decode(state)) == state
|
||||
|
@@ -82,8 +82,8 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
|
||||
env_1 = env_spec.make(disable_env_checker=True)
|
||||
env_2 = env_spec.make(disable_env_checker=True)
|
||||
|
||||
initial_obs_1 = env_1.reset(seed=SEED)
|
||||
initial_obs_2 = env_2.reset(seed=SEED)
|
||||
initial_obs_1, initial_info_1 = env_1.reset(seed=SEED)
|
||||
initial_obs_2, initial_info_2 = env_2.reset(seed=SEED)
|
||||
assert_equals(initial_obs_1, initial_obs_2)
|
||||
|
||||
env_1.action_space.seed(SEED)
|
||||
|
@@ -17,8 +17,8 @@ def verify_environments_match(
|
||||
old_env = envs.make(old_env_id, disable_env_checker=True)
|
||||
new_env = envs.make(new_env_id, disable_env_checker=True)
|
||||
|
||||
old_reset_obs = old_env.reset(seed=seed)
|
||||
new_reset_obs = new_env.reset(seed=seed)
|
||||
old_reset_obs, old_info = old_env.reset(seed=seed)
|
||||
new_reset_obs, new_info = new_env.reset(seed=seed)
|
||||
|
||||
np.testing.assert_allclose(old_reset_obs, new_reset_obs)
|
||||
|
||||
@@ -56,7 +56,7 @@ EXCLUDE_POS_FROM_OBS = [
|
||||
def test_obs_space_mujoco_environments(env_spec: EnvSpec):
|
||||
"""Check that the returned observations are contained in the observation space of the environment"""
|
||||
env = env_spec.make(disable_env_checker=True)
|
||||
reset_obs = env.reset()
|
||||
reset_obs, info = env.reset()
|
||||
assert env.observation_space.contains(
|
||||
reset_obs
|
||||
), f"Obseravtion returned by reset() of {env_spec.id} is not contained in the default observation space {env.observation_space}."
|
||||
@@ -73,7 +73,7 @@ def test_obs_space_mujoco_environments(env_spec: EnvSpec):
|
||||
env = env_spec.make(
|
||||
disable_env_checker=True, exclude_current_positions_from_observation=False
|
||||
)
|
||||
reset_obs = env.reset()
|
||||
reset_obs, info = env.reset()
|
||||
assert env.observation_space.contains(
|
||||
reset_obs
|
||||
), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when excluding current position from observation."
|
||||
@@ -86,7 +86,7 @@ def test_obs_space_mujoco_environments(env_spec: EnvSpec):
|
||||
# Ant-v4 has the option of including contact forces in the observation space with the use_contact_forces argument
|
||||
if env_spec.name == "Ant" and env_spec.version == 4:
|
||||
env = env_spec.make(disable_env_checker=True, use_contact_forces=True)
|
||||
reset_obs = env.reset()
|
||||
reset_obs, info = env.reset()
|
||||
assert env.observation_space.contains(
|
||||
reset_obs
|
||||
), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when using contact forces."
|
||||
|
@@ -21,17 +21,9 @@ class UnittestEnv(core.Env):
|
||||
observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
|
||||
action_space = spaces.Discrete(3)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
super().reset(seed=seed)
|
||||
if return_info:
|
||||
return self.observation_space.sample(), {"info": "dummy"}
|
||||
return self.observation_space.sample() # Dummy observation
|
||||
return self.observation_space.sample(), {"info": "dummy"}
|
||||
|
||||
def step(self, action):
|
||||
observation = self.observation_space.sample() # Dummy observation
|
||||
@@ -45,22 +37,13 @@ class UnknownSpacesEnv(core.Env):
|
||||
on external resources), it is not encouraged.
|
||||
"""
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
super().reset(seed=seed)
|
||||
self.observation_space = spaces.Box(
|
||||
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
|
||||
)
|
||||
self.action_space = spaces.Discrete(3)
|
||||
if not return_info:
|
||||
return self.observation_space.sample() # Dummy observation
|
||||
else:
|
||||
return self.observation_space.sample(), {} # Dummy observation with info
|
||||
return self.observation_space.sample(), {} # Dummy observation with info
|
||||
|
||||
def step(self, action):
|
||||
observation = self.observation_space.sample() # Dummy observation
|
||||
|
@@ -12,16 +12,12 @@ def basic_reset_fn(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||
"""A basic reset function that will pass the environment check using random actions from the observation space."""
|
||||
super(GenericTestEnv, self).reset(seed=seed)
|
||||
self.observation_space.seed(seed)
|
||||
if return_info:
|
||||
return self.observation_space.sample(), {"options": options}
|
||||
else:
|
||||
return self.observation_space.sample()
|
||||
return self.observation_space.sample(), {"options": options}
|
||||
|
||||
|
||||
def new_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
|
||||
@@ -77,7 +73,6 @@ class GenericTestEnv(gym.Env):
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None,
|
||||
) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||
# If you need a default working reset function, use `basic_reset_fn` above
|
||||
|
@@ -1,16 +1,21 @@
|
||||
"""Tests that the `env_checker` runs as expected and all errors are possible."""
|
||||
import re
|
||||
import warnings
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
from gym.core import ObsType
|
||||
from gym.utils.env_checker import (
|
||||
check_env,
|
||||
check_reset_info,
|
||||
check_reset_options,
|
||||
check_reset_return_info_deprecation,
|
||||
check_reset_return_type,
|
||||
check_reset_seed,
|
||||
check_seed_deprecation,
|
||||
)
|
||||
from tests.testing_env import GenericTestEnv
|
||||
|
||||
@@ -48,29 +53,27 @@ def test_no_error_warnings(env):
|
||||
assert len(warnings) == 0, [warning.message for warning in warnings]
|
||||
|
||||
|
||||
def _no_super_reset(self, seed=None, return_info=False, options=None):
|
||||
def _no_super_reset(self, seed=None, options=None):
|
||||
self.np_random.random() # generates a new prng
|
||||
# generate seed deterministic result
|
||||
self.observation_space.seed(0)
|
||||
return self.observation_space.sample()
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
|
||||
def _super_reset_fixed(self, seed=None, return_info=False, options=None):
|
||||
def _super_reset_fixed(self, seed=None, options=None):
|
||||
# Call super that ignores the seed passed, use fixed seed
|
||||
super(GenericTestEnv, self).reset(seed=1)
|
||||
# deterministic output
|
||||
self.observation_space._np_random = self.np_random
|
||||
return self.observation_space.sample()
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
|
||||
def _reset_default_seed(
|
||||
self: GenericTestEnv, seed="Error", return_info=False, options=None
|
||||
):
|
||||
def _reset_default_seed(self: GenericTestEnv, seed="Error", options=None):
|
||||
super(GenericTestEnv, self).reset(seed=seed)
|
||||
self.observation_space._np_random = ( # pyright: ignore [reportPrivateUsage]
|
||||
self.np_random
|
||||
)
|
||||
return self.observation_space.sample()
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -78,12 +81,12 @@ def _reset_default_seed(
|
||||
[
|
||||
[
|
||||
gym.error.Error,
|
||||
lambda self: self.observation_space.sample(),
|
||||
lambda self: (self.observation_space.sample(), {}),
|
||||
"The `reset` method does not provide a `seed` or `**kwargs` keyword argument.",
|
||||
],
|
||||
[
|
||||
AssertionError,
|
||||
lambda self, seed, *_: self.observation_space.sample(),
|
||||
lambda self, seed, *_: (self.observation_space.sample(), {}),
|
||||
"Expects the random number generator to have been generated given a seed was passed to reset. Mostly likely the environment reset function does not call `super().reset(seed=seed)`.",
|
||||
],
|
||||
[
|
||||
@@ -115,95 +118,125 @@ def test_check_reset_seed(test, func: callable, message: str):
|
||||
check_reset_seed(GenericTestEnv(reset_fn=func))
|
||||
|
||||
|
||||
def _deprecated_return_info(
|
||||
self, return_info: bool = False
|
||||
) -> Union[Tuple[ObsType, dict], ObsType]:
|
||||
"""function to simulate the signature and behavior of a `reset` function with the deprecated `return_info` optional argument"""
|
||||
if return_info:
|
||||
return self.observation_space.sample(), {}
|
||||
else:
|
||||
return self.observation_space.sample()
|
||||
|
||||
|
||||
def _reset_var_keyword_kwargs(self, kwargs):
|
||||
return self.observation_space.sample()
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
|
||||
def _reset_return_info_type(self, seed=None, return_info=False, options=None):
|
||||
if return_info:
|
||||
return [1, 2]
|
||||
else:
|
||||
return self.observation_space.sample()
|
||||
def _reset_return_info_type(self, seed=None, options=None):
|
||||
"""Returns a `list` instead of a `tuple`. This function is used to make sure `env_checker` correctly
|
||||
checks that the return type of `env.reset()` is a `tuple`"""
|
||||
return [self.observation_space.sample(), {}]
|
||||
|
||||
|
||||
def _reset_return_info_length(self, seed=None, return_info=False, options=None):
|
||||
if return_info:
|
||||
return 1, 2, 3
|
||||
else:
|
||||
return self.observation_space.sample()
|
||||
def _reset_return_info_length(self, seed=None, options=None):
|
||||
return 1, 2, 3
|
||||
|
||||
|
||||
def _return_info_obs_outside(self, seed=None, return_info=False, options=None):
|
||||
if return_info:
|
||||
return self.observation_space.sample() + self.observation_space.high, {}
|
||||
else:
|
||||
return self.observation_space.sample()
|
||||
def _return_info_obs_outside(self, seed=None, options=None):
|
||||
return self.observation_space.sample() + self.observation_space.high, {}
|
||||
|
||||
|
||||
def _return_info_not_dict(self, seed=None, return_info=False, options=None):
|
||||
if return_info:
|
||||
return self.observation_space.sample(), ["key", "value"]
|
||||
else:
|
||||
return self.observation_space.sample()
|
||||
def _return_info_not_dict(self, seed=None, options=None):
|
||||
return self.observation_space.sample(), ["key", "value"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test,func,message",
|
||||
[
|
||||
[
|
||||
gym.error.Error,
|
||||
lambda self, *_: self.observation_space.sample(),
|
||||
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument.",
|
||||
],
|
||||
[
|
||||
gym.error.Error,
|
||||
_reset_var_keyword_kwargs,
|
||||
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument.",
|
||||
],
|
||||
[
|
||||
AssertionError,
|
||||
_reset_return_info_type,
|
||||
"Calling the reset method with `return_info=True` did not return a tuple, actual type: <class 'list'>",
|
||||
"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `<class 'list'>`",
|
||||
],
|
||||
[
|
||||
AssertionError,
|
||||
_reset_return_info_length,
|
||||
"Calling the reset method with `return_info=True` did not return a 2-tuple, actual length: 3",
|
||||
"Calling the reset method did not return a 2-tuple, actual length: 3",
|
||||
],
|
||||
[
|
||||
AssertionError,
|
||||
_return_info_obs_outside,
|
||||
"The first element returned by `env.reset(return_info=True)` is not within the observation space.",
|
||||
"The first element returned by `env.reset()` is not within the observation space.",
|
||||
],
|
||||
[
|
||||
AssertionError,
|
||||
_return_info_not_dict,
|
||||
"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: <class 'list'>",
|
||||
"The second element returned by `env.reset()` was not a dictionary, actual type: <class 'list'>",
|
||||
],
|
||||
],
|
||||
)
|
||||
def test_check_reset_info(test, func: callable, message: str):
|
||||
"""Tests the check reset info function works as expected."""
|
||||
if test is UserWarning:
|
||||
with pytest.warns(
|
||||
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
||||
):
|
||||
check_reset_info(GenericTestEnv(reset_fn=func))
|
||||
else:
|
||||
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||
check_reset_info(GenericTestEnv(reset_fn=func))
|
||||
def test_check_reset_return_type(test, func: callable, message: str):
|
||||
"""Tests the check `env.reset()` function has a correct return type."""
|
||||
|
||||
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||
check_reset_return_type(GenericTestEnv(reset_fn=func))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test,func,message",
|
||||
[
|
||||
[
|
||||
UserWarning,
|
||||
_deprecated_return_info,
|
||||
"`return_info` is deprecated as an optional argument to `reset`. `reset`"
|
||||
"should now always return `obs, info` where `obs` is an observation, and `info` is a dictionary"
|
||||
"containing additional information.",
|
||||
],
|
||||
],
|
||||
)
|
||||
def test_check_reset_return_info_deprecation(test, func: callable, message: str):
|
||||
"""Tests that return_info has been correctly deprecated as an argument to `env.reset()`."""
|
||||
|
||||
with pytest.warns(test, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"):
|
||||
check_reset_return_info_deprecation(GenericTestEnv(reset_fn=func))
|
||||
|
||||
|
||||
def test_check_seed_deprecation():
|
||||
"""Tests that `check_seed_deprecation()` throws a warning if `env.seed()` has not been removed."""
|
||||
|
||||
message = """Official support for the `seed` function is dropped. Standard practice is to reset gym environments using `env.reset(seed=<desired seed>)`"""
|
||||
|
||||
env = GenericTestEnv()
|
||||
|
||||
def seed(seed):
|
||||
return
|
||||
|
||||
with pytest.warns(
|
||||
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
||||
):
|
||||
env.seed = seed
|
||||
assert callable(env.seed)
|
||||
check_seed_deprecation(env)
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
env.seed = []
|
||||
check_seed_deprecation(env)
|
||||
env.seed = 123
|
||||
check_seed_deprecation(env)
|
||||
del env.seed
|
||||
check_seed_deprecation(env)
|
||||
assert len(caught_warnings) == 0
|
||||
|
||||
|
||||
def test_check_reset_options():
|
||||
"""Tests the check_reset_options function."""
|
||||
|
||||
with pytest.raises(
|
||||
gym.error.Error,
|
||||
match=re.escape(
|
||||
"The `reset` method does not provide an `options` or `**kwargs` keyword argument"
|
||||
),
|
||||
):
|
||||
check_reset_options(GenericTestEnv(reset_fn=lambda self: 0))
|
||||
check_reset_options(GenericTestEnv(reset_fn=lambda self: (0, {})))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@@ -242,24 +242,20 @@ def test_check_obs(test, obs, obs_space: spaces.Space, message: str):
|
||||
assert len(warnings) == 0
|
||||
|
||||
|
||||
def _reset_no_seed(self, return_info=False, options=None):
|
||||
return self.observation_space.sample()
|
||||
def _reset_no_seed(self, options=None):
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
|
||||
def _reset_seed_default(self, seed="error", return_info=False, options=None):
|
||||
return self.observation_space.sample()
|
||||
def _reset_seed_default(self, seed="error", options=None):
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
|
||||
def _reset_no_return_info(self, seed=None, options=None):
|
||||
return self.observation_space.sample()
|
||||
|
||||
|
||||
def _reset_no_option(self, seed=None, return_info=False):
|
||||
return self.observation_space.sample()
|
||||
def _reset_no_option(self, seed=None):
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
|
||||
def _make_reset_results(results):
|
||||
def _reset_result(self, seed=None, return_info=False, options=None):
|
||||
def _reset_result(self, seed=None, options=None):
|
||||
return results
|
||||
|
||||
return _reset_result
|
||||
@@ -280,12 +276,6 @@ def _make_reset_results(results):
|
||||
"The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic. Actual default: seed='error'",
|
||||
{},
|
||||
],
|
||||
[
|
||||
UserWarning,
|
||||
_reset_no_return_info,
|
||||
"Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting.",
|
||||
{},
|
||||
],
|
||||
[
|
||||
UserWarning,
|
||||
_reset_no_option,
|
||||
@@ -293,16 +283,16 @@ def _make_reset_results(results):
|
||||
{},
|
||||
],
|
||||
[
|
||||
AssertionError,
|
||||
UserWarning,
|
||||
_make_reset_results([0, {}]),
|
||||
"The result returned by `env.reset(return_info=True)` was not a tuple, actual type: <class 'list'>",
|
||||
{"return_info": True},
|
||||
"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `<class 'list'>`",
|
||||
{},
|
||||
],
|
||||
[
|
||||
AssertionError,
|
||||
_make_reset_results((0, {1, 2})),
|
||||
"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: <class 'set'>",
|
||||
{"return_info": True},
|
||||
_make_reset_results((np.array([0], dtype=np.float32), {1, 2})),
|
||||
"The second element returned by `env.reset()` was not a dictionary, actual type: <class 'set'>",
|
||||
{},
|
||||
],
|
||||
],
|
||||
)
|
||||
@@ -317,6 +307,8 @@ def test_passive_env_reset_checker(test, func: callable, message: str, kwargs: D
|
||||
with pytest.warns(None) as warnings:
|
||||
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||
env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs)
|
||||
for warning in warnings:
|
||||
print(warning)
|
||||
assert len(warnings) == 0
|
||||
|
||||
|
||||
|
@@ -28,7 +28,7 @@ def test_reset_async_vector_env(shared_memory):
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
observations = env.reset()
|
||||
observations, infos = env.reset()
|
||||
|
||||
env.close()
|
||||
|
||||
@@ -40,19 +40,7 @@ def test_reset_async_vector_env(shared_memory):
|
||||
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
observations = env.reset(return_info=False)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert isinstance(observations, np.ndarray)
|
||||
assert observations.dtype == env.observation_space.dtype
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
observations, infos = env.reset(return_info=True)
|
||||
observations, infos = env.reset()
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
@@ -143,7 +131,7 @@ def test_copy_async_vector_env(shared_memory):
|
||||
|
||||
# TODO, these tests do nothing, understand the purpose of the tests and fix them
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=True)
|
||||
observations = env.reset()
|
||||
observations, infos = env.reset()
|
||||
observations[0] = 0
|
||||
|
||||
env.close()
|
||||
@@ -155,7 +143,7 @@ def test_no_copy_async_vector_env(shared_memory):
|
||||
|
||||
# TODO, these tests do nothing, understand the purpose of the tests and fix them
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=False)
|
||||
observations = env.reset()
|
||||
observations, infos = env.reset()
|
||||
observations[0] = 0
|
||||
|
||||
env.close()
|
||||
@@ -268,7 +256,7 @@ def test_custom_space_async_vector_env():
|
||||
env_fns = [make_custom_space_env(i) for i in range(4)]
|
||||
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=False)
|
||||
reset_observations = env.reset()
|
||||
reset_observations, reset_infos = env.reset()
|
||||
|
||||
assert isinstance(env.single_action_space, CustomSpace)
|
||||
assert isinstance(env.action_space, Tuple)
|
||||
|
@@ -12,7 +12,7 @@ class OldStepEnv(gym.Env):
|
||||
self.observation_space = Discrete(2)
|
||||
|
||||
def reset(self):
|
||||
return 0
|
||||
return 0, {}
|
||||
|
||||
def step(self, action):
|
||||
obs = self.observation_space.sample()
|
||||
@@ -28,7 +28,7 @@ class NewStepEnv(gym.Env):
|
||||
self.observation_space = Discrete(2)
|
||||
|
||||
def reset(self):
|
||||
return 0
|
||||
return 0, {}
|
||||
|
||||
def step(self, action):
|
||||
obs = self.observation_space.sample()
|
||||
|
@@ -24,7 +24,7 @@ def test_create_sync_vector_env():
|
||||
def test_reset_sync_vector_env():
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
env = SyncVectorEnv(env_fns)
|
||||
observations = env.reset()
|
||||
observations, infos = env.reset()
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
@@ -35,32 +35,6 @@ def test_reset_sync_vector_env():
|
||||
|
||||
del observations
|
||||
|
||||
env = SyncVectorEnv(env_fns)
|
||||
observations = env.reset(return_info=False)
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert isinstance(observations, np.ndarray)
|
||||
assert observations.dtype == env.observation_space.dtype
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
|
||||
del observations
|
||||
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
|
||||
env = SyncVectorEnv(env_fns)
|
||||
observations, infos = env.reset(return_info=True)
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert isinstance(observations, np.ndarray)
|
||||
assert observations.dtype == env.observation_space.dtype
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
assert isinstance(infos, dict)
|
||||
assert all([isinstance(info, dict) for info in infos])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
||||
def test_step_sync_vector_env(use_single_action_space):
|
||||
@@ -145,7 +119,7 @@ def test_custom_space_sync_vector_env():
|
||||
env_fns = [make_custom_space_env(i) for i in range(4)]
|
||||
|
||||
env = SyncVectorEnv(env_fns)
|
||||
reset_observations = env.reset()
|
||||
reset_observations, infos = env.reset()
|
||||
|
||||
assert isinstance(env.single_action_space, CustomSpace)
|
||||
assert isinstance(env.action_space, Tuple)
|
||||
|
@@ -22,8 +22,8 @@ def test_vector_env_equal(shared_memory):
|
||||
assert async_env.action_space == sync_env.action_space
|
||||
assert async_env.single_action_space == sync_env.single_action_space
|
||||
|
||||
async_observations = async_env.reset(seed=0)
|
||||
sync_observations = sync_env.reset(seed=0)
|
||||
async_observations, async_infos = async_env.reset(seed=0)
|
||||
sync_observations, sync_infos = sync_env.reset(seed=0)
|
||||
assert np.all(async_observations == sync_observations)
|
||||
|
||||
for _ in range(num_steps):
|
||||
|
@@ -63,7 +63,7 @@ class UnittestSlowEnv(gym.Env):
|
||||
super().reset(seed=seed)
|
||||
if self.slow_reset > 0:
|
||||
time.sleep(self.slow_reset)
|
||||
return self.observation_space.sample()
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
def step(self, action):
|
||||
time.sleep(action)
|
||||
@@ -99,7 +99,7 @@ class CustomSpaceEnv(gym.Env):
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
super().reset(seed=seed)
|
||||
return "reset"
|
||||
return "reset", {}
|
||||
|
||||
def step(self, action):
|
||||
observation = f"step({action:s})"
|
||||
|
@@ -86,9 +86,7 @@ def test_atari_preprocessing_grayscale(env, obs_shape):
|
||||
# It is not possible to test the outputs as we are not using actual observations.
|
||||
# todo: update when ale-py is compatible with the ci
|
||||
|
||||
obs = env.reset(seed=0)
|
||||
assert obs in env.observation_space
|
||||
obs, _ = env.reset(seed=0, return_info=True)
|
||||
obs, _ = env.reset(seed=0)
|
||||
assert obs in env.observation_space
|
||||
|
||||
obs, _, _, _ = env.step(env.action_space.sample())
|
||||
@@ -110,7 +108,7 @@ def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10):
|
||||
noop_max=0,
|
||||
)
|
||||
|
||||
obs = env.reset()
|
||||
obs, _ = env.reset()
|
||||
|
||||
max_obs = 1 if scaled else 255
|
||||
assert np.all(0 <= obs) and np.all(obs <= max_obs)
|
||||
|
@@ -39,19 +39,10 @@ class DummyResetEnv(gym.Env):
|
||||
{"count": self.count}, # Info
|
||||
)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: Optional[bool] = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
"""Resets the DummyEnv to return the count array and info with count."""
|
||||
self.count = 0
|
||||
if not return_info:
|
||||
return np.array([self.count])
|
||||
else:
|
||||
return np.array([self.count]), {"count": self.count}
|
||||
return np.array([self.count]), {"count": self.count}
|
||||
|
||||
|
||||
def unwrap_env(env) -> Generator[gym.Wrapper, None, None]:
|
||||
@@ -113,7 +104,7 @@ def test_autoreset_wrapper_autoreset():
|
||||
env = DummyResetEnv()
|
||||
env = AutoResetWrapper(env)
|
||||
|
||||
obs, info = env.reset(return_info=True)
|
||||
obs, info = env.reset()
|
||||
assert obs == np.array([0])
|
||||
assert info == {"count": 0}
|
||||
|
||||
|
@@ -25,16 +25,10 @@ class FakeEnvironment(gym.Env):
|
||||
image_shape = (32, 32, 3)
|
||||
return np.zeros(image_shape, dtype=np.uint8)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: bool = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
super().reset(seed=seed)
|
||||
observation = self.observation_space.sample()
|
||||
return observation if not return_info else (observation, {})
|
||||
return observation, {}
|
||||
|
||||
def step(self, action):
|
||||
del action
|
||||
@@ -79,8 +73,9 @@ class TestFilterObservation:
|
||||
assert tuple(wrapped_env.observation_space.spaces.keys()) == tuple(filter_keys)
|
||||
|
||||
# Check that the added space item is consistent with the added observation.
|
||||
observation = wrapped_env.reset()
|
||||
observation, info = wrapped_env.reset()
|
||||
assert len(observation) == len(filter_keys)
|
||||
assert isinstance(info, dict)
|
||||
|
||||
@pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES)
|
||||
def test_raises_with_incorrect_arguments(
|
||||
|
@@ -18,7 +18,7 @@ class FakeEnvironment(gym.Env):
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
super().reset(seed=seed)
|
||||
self.observation = self.observation_space.sample()
|
||||
return self.observation
|
||||
return self.observation, {}
|
||||
|
||||
|
||||
OBSERVATION_SPACES = (
|
||||
@@ -67,7 +67,7 @@ class TestFlattenEnvironment:
|
||||
"""
|
||||
env = FakeEnvironment(observation_space=observation_space)
|
||||
wrapped_env = FlattenObservation(env)
|
||||
flattened = wrapped_env.reset()
|
||||
flattened, info = wrapped_env.reset()
|
||||
|
||||
unflattened = unflatten(env.observation_space, flattened)
|
||||
original = env.observation
|
||||
|
@@ -11,11 +11,13 @@ def test_flatten_observation(env_id):
|
||||
env = gym.make(env_id, disable_env_checker=True)
|
||||
wrapped_env = FlattenObservation(env)
|
||||
|
||||
obs = env.reset()
|
||||
wrapped_obs = wrapped_env.reset()
|
||||
obs, info = env.reset()
|
||||
wrapped_obs, wrapped_obs_info = wrapped_env.reset()
|
||||
|
||||
space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
|
||||
wrapped_space = spaces.Box(0, 1, [32 + 11 + 2], dtype=np.int64)
|
||||
|
||||
assert space.contains(obs)
|
||||
assert wrapped_space.contains(wrapped_obs)
|
||||
assert isinstance(info, dict)
|
||||
assert isinstance(wrapped_obs_info, dict)
|
||||
|
@@ -33,8 +33,8 @@ def test_frame_stack(env_id, num_stack, lz4_compress):
|
||||
|
||||
dup = gym.make(env_id, disable_env_checker=True)
|
||||
|
||||
obs = env.reset(seed=0)
|
||||
dup_obs = dup.reset(seed=0)
|
||||
obs, _ = env.reset(seed=0)
|
||||
dup_obs, _ = dup.reset(seed=0)
|
||||
assert np.allclose(obs[-1], dup_obs)
|
||||
|
||||
for _ in range(num_stack**2):
|
||||
|
@@ -22,5 +22,5 @@ def test_gray_scale_observation(env_id, keep_dim):
|
||||
else:
|
||||
assert len(wrapped_env.observation_space.shape) == 2
|
||||
|
||||
wrapped_obs = wrapped_env.reset()
|
||||
wrapped_obs, info = wrapped_env.reset()
|
||||
assert wrapped_obs in wrapped_env.observation_space
|
||||
|
@@ -23,7 +23,7 @@ class FakeEnvironment(gym.Env):
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
super().reset(seed=seed)
|
||||
observation = self.observation_space.sample()
|
||||
return observation
|
||||
return observation, {}
|
||||
|
||||
def step(self, action):
|
||||
del action
|
||||
@@ -115,5 +115,6 @@ class TestNestedDictWrapper:
|
||||
def test_nested_dicts_ravel(self, observation_space, flat_shape):
|
||||
env = FakeEnvironment(observation_space=observation_space)
|
||||
wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
|
||||
obs = wrapped_env.reset()
|
||||
obs, info = wrapped_env.reset()
|
||||
assert obs.shape == wrapped_env.observation_space.shape
|
||||
assert isinstance(info, dict)
|
||||
|
@@ -1,7 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from numpy.testing import assert_almost_equal
|
||||
|
||||
import gym
|
||||
@@ -24,19 +23,10 @@ class DummyRewardEnv(gym.Env):
|
||||
self.t += 1
|
||||
return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
return_info: Optional[bool] = False,
|
||||
options: Optional[dict] = None
|
||||
):
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
super().reset(seed=seed)
|
||||
self.t = self.return_reward_idx
|
||||
if not return_info:
|
||||
return np.array([self.t])
|
||||
else:
|
||||
return np.array([self.t]), {}
|
||||
return np.array([self.t]), {}
|
||||
|
||||
|
||||
def make_env(return_reward_idx):
|
||||
@@ -47,11 +37,10 @@ def make_env(return_reward_idx):
|
||||
return thunk
|
||||
|
||||
|
||||
@pytest.mark.parametrize("return_info", [False, True])
|
||||
def test_normalize_observation(return_info: bool):
|
||||
def test_normalize_observation():
|
||||
env = DummyRewardEnv(return_reward_idx=0)
|
||||
env = NormalizeObservation(env)
|
||||
env.reset(return_info=return_info)
|
||||
env.reset()
|
||||
env.step(env.action_space.sample())
|
||||
assert_almost_equal(env.obs_rms.mean, 0.5, decimal=4)
|
||||
env.step(env.action_space.sample())
|
||||
@@ -61,13 +50,7 @@ def test_normalize_observation(return_info: bool):
|
||||
def test_normalize_reset_info():
|
||||
env = DummyRewardEnv(return_reward_idx=0)
|
||||
env = NormalizeObservation(env)
|
||||
obs = env.reset()
|
||||
assert isinstance(obs, np.ndarray)
|
||||
del obs
|
||||
obs = env.reset(return_info=False)
|
||||
assert isinstance(obs, np.ndarray)
|
||||
del obs
|
||||
obs, info = env.reset(return_info=True)
|
||||
obs, info = env.reset()
|
||||
assert isinstance(obs, np.ndarray)
|
||||
assert isinstance(info, dict)
|
||||
|
||||
|
@@ -57,8 +57,8 @@ def test_initialise_failures(env, message):
|
||||
env.close()
|
||||
|
||||
|
||||
def _reset_failure(self, seed=None, return_info=False, options=None):
|
||||
return np.array([-1.0], dtype=np.float32)
|
||||
def _reset_failure(self, seed=None, options=None):
|
||||
return np.array([-1.0], dtype=np.float32), {}
|
||||
|
||||
|
||||
def _step_failure(self, action):
|
||||
|
@@ -21,7 +21,7 @@ class FakeEnvironment(gym.Env):
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
super().reset(seed=seed)
|
||||
observation = self.observation_space.sample()
|
||||
return observation
|
||||
return observation, {}
|
||||
|
||||
def step(self, action):
|
||||
del action
|
||||
@@ -82,9 +82,10 @@ def test_dict_observation(pixels_only):
|
||||
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
|
||||
|
||||
# Check that the added space item is consistent with the added observation.
|
||||
observation = wrapped_env.reset()
|
||||
observation, info = wrapped_env.reset()
|
||||
rgb_observation = observation[pixel_key]
|
||||
|
||||
assert isinstance(info, dict)
|
||||
assert rgb_observation.shape == (height, width, 3)
|
||||
assert rgb_observation.dtype == np.uint8
|
||||
|
||||
@@ -113,9 +114,10 @@ def test_single_array_observation(pixels_only):
|
||||
pixel_key,
|
||||
]
|
||||
|
||||
observation = wrapped_env.reset()
|
||||
observation, info = wrapped_env.reset()
|
||||
depth_observation = observation[pixel_key]
|
||||
|
||||
assert isinstance(info, dict)
|
||||
assert depth_observation.shape == (32, 32, 3)
|
||||
assert depth_observation.dtype == np.uint8
|
||||
|
||||
|
@@ -31,10 +31,7 @@ def test_record_episode_statistics_reset_info():
|
||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
||||
env = RecordEpisodeStatistics(env)
|
||||
ob_space = env.observation_space
|
||||
obs = env.reset()
|
||||
assert ob_space.contains(obs)
|
||||
del obs
|
||||
obs, info = env.reset(return_info=True)
|
||||
obs, info = env.reset()
|
||||
assert ob_space.contains(obs)
|
||||
assert isinstance(info, dict)
|
||||
|
||||
|
@@ -23,35 +23,17 @@ def test_record_video_using_default_trigger():
|
||||
shutil.rmtree("videos")
|
||||
|
||||
|
||||
def test_record_video_reset_return_info():
|
||||
def test_record_video_reset():
|
||||
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
|
||||
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
|
||||
ob_space = env.observation_space
|
||||
obs, info = env.reset(return_info=True)
|
||||
obs, info = env.reset()
|
||||
env.close()
|
||||
assert os.path.isdir("videos")
|
||||
shutil.rmtree("videos")
|
||||
assert ob_space.contains(obs)
|
||||
assert isinstance(info, dict)
|
||||
|
||||
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
|
||||
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
|
||||
ob_space = env.observation_space
|
||||
obs = env.reset(return_info=False)
|
||||
env.close()
|
||||
assert os.path.isdir("videos")
|
||||
shutil.rmtree("videos")
|
||||
assert ob_space.contains(obs)
|
||||
|
||||
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
|
||||
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
|
||||
ob_space = env.observation_space
|
||||
obs = env.reset()
|
||||
env.close()
|
||||
assert os.path.isdir("videos")
|
||||
shutil.rmtree("videos")
|
||||
assert ob_space.contains(obs)
|
||||
|
||||
|
||||
def test_record_video_step_trigger():
|
||||
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
|
||||
|
@@ -18,8 +18,8 @@ def test_rescale_action():
|
||||
|
||||
seed = 0
|
||||
|
||||
obs = env.reset(seed=seed)
|
||||
wrapped_obs = wrapped_env.reset(seed=seed)
|
||||
obs, info = env.reset(seed=seed)
|
||||
wrapped_obs, wrapped_obs_info = wrapped_env.reset(seed=seed)
|
||||
assert np.allclose(obs, wrapped_obs)
|
||||
|
||||
obs, reward, _, _ = env.step([1.5])
|
||||
|
@@ -13,7 +13,7 @@ def test_resize_observation(env_id, shape):
|
||||
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert env.observation_space.shape[-1] == 3
|
||||
obs = env.reset()
|
||||
obs, _ = env.reset()
|
||||
if isinstance(shape, int):
|
||||
assert env.observation_space.shape[:2] == (shape, shape)
|
||||
assert obs.shape == (shape, shape, 3)
|
||||
|
@@ -14,8 +14,8 @@ def test_time_aware_observation(env_id):
|
||||
assert isinstance(wrapped_env.observation_space, spaces.Box)
|
||||
assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1
|
||||
|
||||
obs = env.reset()
|
||||
wrapped_obs = wrapped_env.reset()
|
||||
obs, info = env.reset()
|
||||
wrapped_obs, wrapped_obs_info = wrapped_env.reset()
|
||||
assert wrapped_env.t == 0.0
|
||||
assert wrapped_obs[-1] == 0.0
|
||||
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
||||
@@ -30,7 +30,7 @@ def test_time_aware_observation(env_id):
|
||||
assert wrapped_obs[-1] == 2.0
|
||||
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
||||
|
||||
wrapped_obs = wrapped_env.reset()
|
||||
wrapped_obs, wrapped_obs_info = wrapped_env.reset()
|
||||
assert wrapped_env.t == 0.0
|
||||
assert wrapped_obs[-1] == 0.0
|
||||
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
||||
|
@@ -9,13 +9,7 @@ def test_time_limit_reset_info():
|
||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
||||
env = TimeLimit(env)
|
||||
ob_space = env.observation_space
|
||||
obs = env.reset()
|
||||
assert ob_space.contains(obs)
|
||||
del obs
|
||||
obs = env.reset(return_info=False)
|
||||
assert ob_space.contains(obs)
|
||||
del obs
|
||||
obs, info = env.reset(return_info=True)
|
||||
obs, info = env.reset()
|
||||
assert ob_space.contains(obs)
|
||||
assert isinstance(info, dict)
|
||||
|
||||
|
@@ -15,9 +15,10 @@ def test_transform_observation(env_id):
|
||||
gym.make(env_id, disable_env_checker=True), lambda obs: affine_transform(obs)
|
||||
)
|
||||
|
||||
obs = env.reset(seed=0)
|
||||
wrapped_obs = wrapped_env.reset(seed=0)
|
||||
obs, info = env.reset(seed=0)
|
||||
wrapped_obs, wrapped_obs_info = wrapped_env.reset(seed=0)
|
||||
assert np.allclose(wrapped_obs, affine_transform(obs))
|
||||
assert isinstance(wrapped_obs_info, dict)
|
||||
|
||||
action = env.action_space.sample()
|
||||
obs, reward, done, _ = env.step(action)
|
||||
|
@@ -23,7 +23,7 @@ def test_info_to_list():
|
||||
env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)
|
||||
wrapped_env = VectorListInfo(env_to_wrap)
|
||||
wrapped_env.action_space.seed(SEED)
|
||||
_, info = wrapped_env.reset(seed=SEED, return_info=True)
|
||||
_, info = wrapped_env.reset(seed=SEED)
|
||||
assert isinstance(info, list)
|
||||
assert len(info) == NUM_ENVS
|
||||
|
||||
@@ -40,7 +40,7 @@ def test_info_to_list():
|
||||
def test_info_to_list_statistics():
|
||||
env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)
|
||||
wrapped_env = VectorListInfo(RecordEpisodeStatistics(env_to_wrap))
|
||||
_, info = wrapped_env.reset(seed=SEED, return_info=True)
|
||||
_, info = wrapped_env.reset(seed=SEED)
|
||||
wrapped_env.action_space.seed(SEED)
|
||||
assert isinstance(info, list)
|
||||
assert len(info) == NUM_ENVS
|
||||
|
Reference in New Issue
Block a user