mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-29 01:27:29 +00:00
Removing return_info argument to env.reset() and deprecated env.seed() function (reset now always returns info) (#2962)
* removed return_info, made info dict mandatory in reset * tenatively removed deprecated seed api for environments * added more info type checks to wrapper tests * formatting/style compliance * addressed some comments * polish to address review * fixed tests after merge, and added a test of the return_info deprecation assertion if found in reset signature * some organization of env_checker tests, reverted a probably merge error * added deprecation check for seed function in env * updated docstring * removed debug prints, tweaked test_check_seed_deprecation * changed return_info deprecation check from assertion to warning * fixes to vector envs, now should be correctly structured * added some explanation and typehints for mockup depcreated return info reset function * re-removed seed function from vector envs * added explanation to _reset_return_info_type and changed the return statement
This commit is contained in:
@@ -23,14 +23,14 @@ The Gym API's API models environments as simple Python `env` classes. Creating e
|
|||||||
```python
|
```python
|
||||||
import gym
|
import gym
|
||||||
env = gym.make("CartPole-v1")
|
env = gym.make("CartPole-v1")
|
||||||
observation, info = env.reset(seed=42, return_info=True)
|
observation, info = env.reset(seed=42)
|
||||||
|
|
||||||
for _ in range(1000):
|
for _ in range(1000):
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
observation, reward, done, info = env.step(action)
|
observation, reward, done, info = env.step(action)
|
||||||
|
|
||||||
if done:
|
if done:
|
||||||
observation, info = env.reset(return_info=True)
|
observation, info = env.reset()
|
||||||
env.close()
|
env.close()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
53
gym/core.py
53
gym/core.py
@@ -41,11 +41,10 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
The main API methods that users of this class need to know are:
|
The main API methods that users of this class need to know are:
|
||||||
|
|
||||||
- :meth:`step` - Takes a step in the environment using an action returning the next observation, reward,
|
- :meth:`step` - Takes a step in the environment using an action returning the next observation, reward,
|
||||||
if the environment terminated and more information.
|
if the environment terminated and observation information.
|
||||||
- :meth:`reset` - Resets the environment to an initial state, returning the initial observation.
|
- :meth:`reset` - Resets the environment to an initial state, returning the initial observation and observation information.
|
||||||
- :meth:`render` - Renders the environment observation with modes depending on the output
|
- :meth:`render` - Renders the environment observation with modes depending on the output
|
||||||
- :meth:`close` - Closes the environment, important for rendering where pygame is imported
|
- :meth:`close` - Closes the environment, important for rendering where pygame is imported
|
||||||
- :meth:`seed` - Seeds the environment's random number generator, :deprecated: in favor of `Env.reset(seed=seed)`.
|
|
||||||
|
|
||||||
And set the following attributes:
|
And set the following attributes:
|
||||||
|
|
||||||
@@ -124,9 +123,8 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
) -> Union[ObsType, Tuple[ObsType, dict]]:
|
) -> Tuple[ObsType, dict]:
|
||||||
"""Resets the environment to an initial state and returns the initial observation.
|
"""Resets the environment to an initial state and returns the initial observation.
|
||||||
|
|
||||||
This method can reset the environment's random number generator(s) if ``seed`` is an integer or
|
This method can reset the environment's random number generator(s) if ``seed`` is an integer or
|
||||||
@@ -143,8 +141,6 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
If you pass an integer, the PRNG will be reset even if it already exists.
|
If you pass an integer, the PRNG will be reset even if it already exists.
|
||||||
Usually, you want to pass an integer *right after the environment has been initialized and then never again*.
|
Usually, you want to pass an integer *right after the environment has been initialized and then never again*.
|
||||||
Please refer to the minimal example above to see this paradigm in action.
|
Please refer to the minimal example above to see this paradigm in action.
|
||||||
return_info (bool): If true, return additional information along with initial observation.
|
|
||||||
This info should be analogous to the info returned in :meth:`step`
|
|
||||||
options (optional dict): Additional information to specify how the environment is reset (optional,
|
options (optional dict): Additional information to specify how the environment is reset (optional,
|
||||||
depending on the specific environment)
|
depending on the specific environment)
|
||||||
|
|
||||||
@@ -152,8 +148,7 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
Returns:
|
Returns:
|
||||||
observation (object): Observation of the initial state. This will be an element of :attr:`observation_space`
|
observation (object): Observation of the initial state. This will be an element of :attr:`observation_space`
|
||||||
(typically a numpy array) and is analogous to the observation returned by :meth:`step`.
|
(typically a numpy array) and is analogous to the observation returned by :meth:`step`.
|
||||||
info (optional dictionary): This will *only* be returned if ``return_info=True`` is passed.
|
info (dictionary): This dictionary contains auxiliary information complementing ``observation``. It should be analogous to
|
||||||
It contains auxiliary information complementing ``observation``. This dictionary should be analogous to
|
|
||||||
the ``info`` returned by :meth:`step`.
|
the ``info`` returned by :meth:`step`.
|
||||||
"""
|
"""
|
||||||
# Initialize the RNG if the seed is manually passed
|
# Initialize the RNG if the seed is manually passed
|
||||||
@@ -193,33 +188,6 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def seed(self, seed=None):
|
|
||||||
""":deprecated: function that sets the seed for the environment's random number generator(s).
|
|
||||||
|
|
||||||
Use `env.reset(seed=seed)` as the new API for setting the seed of the environment.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
Some environments use multiple pseudorandom number generators.
|
|
||||||
We want to capture all such seeds used in order to ensure that
|
|
||||||
there aren't accidental correlations between multiple generators.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed(Optional int): The seed value for the random number generator
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
seeds (List[int]): Returns the list of seeds used in this environment's random
|
|
||||||
number generators. The first value in the list should be the
|
|
||||||
"main" seed, or the value which a reproducer should pass to
|
|
||||||
'seed'. Often, the main seed equals the provided 'seed', but
|
|
||||||
this won't be true `if seed=None`, for example.
|
|
||||||
"""
|
|
||||||
deprecation(
|
|
||||||
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
|
|
||||||
"Please use `env.reset(seed=seed)` instead."
|
|
||||||
)
|
|
||||||
self._np_random, seed = seeding.np_random(seed)
|
|
||||||
return [seed]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unwrapped(self) -> "Env":
|
def unwrapped(self) -> "Env":
|
||||||
"""Returns the base non-wrapped environment.
|
"""Returns the base non-wrapped environment.
|
||||||
@@ -370,7 +338,7 @@ class Wrapper(Env[ObsType, ActType]):
|
|||||||
|
|
||||||
return step_api_compatibility(self.env.step(action), self.new_step_api)
|
return step_api_compatibility(self.env.step(action), self.new_step_api)
|
||||||
|
|
||||||
def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]:
|
def reset(self, **kwargs) -> Tuple[ObsType, dict]:
|
||||||
"""Resets the environment with kwargs."""
|
"""Resets the environment with kwargs."""
|
||||||
return self.env.reset(**kwargs)
|
return self.env.reset(**kwargs)
|
||||||
|
|
||||||
@@ -384,10 +352,6 @@ class Wrapper(Env[ObsType, ActType]):
|
|||||||
"""Closes the environment."""
|
"""Closes the environment."""
|
||||||
return self.env.close()
|
return self.env.close()
|
||||||
|
|
||||||
def seed(self, seed=None):
|
|
||||||
"""Seeds the environment."""
|
|
||||||
return self.env.seed(seed)
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
"""Returns the wrapper name and the unwrapped environment string."""
|
"""Returns the wrapper name and the unwrapped environment string."""
|
||||||
return f"<{type(self).__name__}{self.env}>"
|
return f"<{type(self).__name__}{self.env}>"
|
||||||
@@ -432,11 +396,8 @@ class ObservationWrapper(Wrapper):
|
|||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
"""Resets the environment, returning a modified observation using :meth:`self.observation`."""
|
"""Resets the environment, returning a modified observation using :meth:`self.observation`."""
|
||||||
if kwargs.get("return_info", False):
|
obs, info = self.env.reset(**kwargs)
|
||||||
obs, info = self.env.reset(**kwargs)
|
return self.observation(obs), info
|
||||||
return self.observation(obs), info
|
|
||||||
else:
|
|
||||||
return self.observation(self.env.reset(**kwargs))
|
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
|
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
|
||||||
|
@@ -428,7 +428,6 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
@@ -514,10 +513,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
|
|
||||||
self.lidar = [LidarCallback() for _ in range(10)]
|
self.lidar = [LidarCallback() for _ in range(10)]
|
||||||
self.renderer.reset()
|
self.renderer.reset()
|
||||||
if not return_info:
|
return self.step(np.array([0, 0, 0, 0]))[0], {}
|
||||||
return self.step(np.array([0, 0, 0, 0]))[0]
|
|
||||||
else:
|
|
||||||
return self.step(np.array([0, 0, 0, 0]))[0], {}
|
|
||||||
|
|
||||||
def step(self, action: np.ndarray):
|
def step(self, action: np.ndarray):
|
||||||
assert self.hull is not None
|
assert self.hull is not None
|
||||||
|
@@ -483,7 +483,6 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
@@ -519,10 +518,7 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
self.car = Car(self.world, *self.track[0][1:4])
|
self.car = Car(self.world, *self.track[0][1:4])
|
||||||
|
|
||||||
self.renderer.reset()
|
self.renderer.reset()
|
||||||
if not return_info:
|
return self.step(None)[0], {}
|
||||||
return self.step(None)[0]
|
|
||||||
else:
|
|
||||||
return self.step(None)[0], {}
|
|
||||||
|
|
||||||
def step(self, action: Union[np.ndarray, int]):
|
def step(self, action: Union[np.ndarray, int]):
|
||||||
assert self.car is not None
|
assert self.car is not None
|
||||||
|
@@ -305,7 +305,6 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
@@ -413,10 +412,7 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
self.drawlist = [self.lander] + self.legs
|
self.drawlist = [self.lander] + self.legs
|
||||||
|
|
||||||
self.renderer.reset()
|
self.renderer.reset()
|
||||||
if not return_info:
|
return self.step(np.array([0, 0]) if self.continuous else 0)[0], {}
|
||||||
return self.step(np.array([0, 0]) if self.continuous else 0)[0]
|
|
||||||
else:
|
|
||||||
return self.step(np.array([0, 0]) if self.continuous else 0)[0], {}
|
|
||||||
|
|
||||||
def _create_particle(self, mass, x, y, ttl):
|
def _create_particle(self, mass, x, y, ttl):
|
||||||
p = self.world.CreateDynamicBody(
|
p = self.world.CreateDynamicBody(
|
||||||
@@ -774,7 +770,7 @@ def demo_heuristic_lander(env, seed=None, render=False):
|
|||||||
|
|
||||||
total_reward = 0
|
total_reward = 0
|
||||||
steps = 0
|
steps = 0
|
||||||
s = env.reset(seed=seed)
|
s, info = env.reset(seed=seed)
|
||||||
while True:
|
while True:
|
||||||
a = heuristic(env, s)
|
a = heuristic(env, s)
|
||||||
s, r, terminated, truncated, info = step_api_compatibility(env.step(a), True)
|
s, r, terminated, truncated, info = step_api_compatibility(env.step(a), True)
|
||||||
|
@@ -180,13 +180,7 @@ class AcrobotEnv(core.Env):
|
|||||||
self.action_space = spaces.Discrete(3)
|
self.action_space = spaces.Discrete(3)
|
||||||
self.state = None
|
self.state = None
|
||||||
|
|
||||||
def reset(
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
self,
|
|
||||||
*,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None
|
|
||||||
):
|
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
# Note that if you use custom reset bounds, it may lead to out-of-bound
|
# Note that if you use custom reset bounds, it may lead to out-of-bound
|
||||||
# state/observations.
|
# state/observations.
|
||||||
@@ -199,10 +193,7 @@ class AcrobotEnv(core.Env):
|
|||||||
|
|
||||||
self.renderer.reset()
|
self.renderer.reset()
|
||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
if not return_info:
|
return self._get_ob(), {}
|
||||||
return self._get_ob()
|
|
||||||
else:
|
|
||||||
return self._get_ob(), {}
|
|
||||||
|
|
||||||
def step(self, a):
|
def step(self, a):
|
||||||
s = self.state
|
s = self.state
|
||||||
|
@@ -192,7 +192,6 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
@@ -205,10 +204,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
|||||||
self.steps_beyond_terminated = None
|
self.steps_beyond_terminated = None
|
||||||
self.renderer.reset()
|
self.renderer.reset()
|
||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
if not return_info:
|
return np.array(self.state, dtype=np.float32), {}
|
||||||
return np.array(self.state, dtype=np.float32)
|
|
||||||
else:
|
|
||||||
return np.array(self.state, dtype=np.float32), {}
|
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return self.renderer.get_renders()
|
return self.renderer.get_renders()
|
||||||
|
@@ -174,13 +174,7 @@ class Continuous_MountainCarEnv(gym.Env):
|
|||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
return self.state, reward, terminated, False, {}
|
return self.state, reward, terminated, False, {}
|
||||||
|
|
||||||
def reset(
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
self,
|
|
||||||
*,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None
|
|
||||||
):
|
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
# Note that if you use custom reset bounds, it may lead to out-of-bound
|
# Note that if you use custom reset bounds, it may lead to out-of-bound
|
||||||
# state/observations.
|
# state/observations.
|
||||||
@@ -188,10 +182,7 @@ class Continuous_MountainCarEnv(gym.Env):
|
|||||||
self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
|
self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
|
||||||
self.renderer.reset()
|
self.renderer.reset()
|
||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
if not return_info:
|
return np.array(self.state, dtype=np.float32), {}
|
||||||
return np.array(self.state, dtype=np.float32)
|
|
||||||
else:
|
|
||||||
return np.array(self.state, dtype=np.float32), {}
|
|
||||||
|
|
||||||
def _height(self, xs):
|
def _height(self, xs):
|
||||||
return np.sin(3 * xs) * 0.45 + 0.55
|
return np.sin(3 * xs) * 0.45 + 0.55
|
||||||
|
@@ -152,7 +152,6 @@ class MountainCarEnv(gym.Env):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
@@ -162,10 +161,7 @@ class MountainCarEnv(gym.Env):
|
|||||||
self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
|
self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
|
||||||
self.renderer.reset()
|
self.renderer.reset()
|
||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
if not return_info:
|
return np.array(self.state, dtype=np.float32), {}
|
||||||
return np.array(self.state, dtype=np.float32)
|
|
||||||
else:
|
|
||||||
return np.array(self.state, dtype=np.float32), {}
|
|
||||||
|
|
||||||
def _height(self, xs):
|
def _height(self, xs):
|
||||||
return np.sin(3 * xs) * 0.45 + 0.55
|
return np.sin(3 * xs) * 0.45 + 0.55
|
||||||
|
@@ -138,13 +138,7 @@ class PendulumEnv(gym.Env):
|
|||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
return self._get_obs(), -costs, False, False, {}
|
return self._get_obs(), -costs, False, False, {}
|
||||||
|
|
||||||
def reset(
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
self,
|
|
||||||
*,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None
|
|
||||||
):
|
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
if options is None:
|
if options is None:
|
||||||
high = np.array([DEFAULT_X, DEFAULT_Y])
|
high = np.array([DEFAULT_X, DEFAULT_Y])
|
||||||
@@ -162,10 +156,7 @@ class PendulumEnv(gym.Env):
|
|||||||
|
|
||||||
self.renderer.reset()
|
self.renderer.reset()
|
||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
if not return_info:
|
return self._get_obs(), {}
|
||||||
return self._get_obs()
|
|
||||||
else:
|
|
||||||
return self._get_obs(), {}
|
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
theta, thetadot = self.state
|
theta, thetadot = self.state
|
||||||
|
@@ -142,7 +142,6 @@ class BaseMujocoEnv(gym.Env):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
@@ -152,10 +151,7 @@ class BaseMujocoEnv(gym.Env):
|
|||||||
ob = self.reset_model()
|
ob = self.reset_model()
|
||||||
self.renderer.reset()
|
self.renderer.reset()
|
||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
if not return_info:
|
return ob, {}
|
||||||
return ob
|
|
||||||
else:
|
|
||||||
return ob, {}
|
|
||||||
|
|
||||||
def set_state(self, qpos, qvel):
|
def set_state(self, qpos, qvel):
|
||||||
"""
|
"""
|
||||||
|
@@ -167,7 +167,6 @@ class BlackjackEnv(gym.Env):
|
|||||||
def reset(
|
def reset(
|
||||||
self,
|
self,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
@@ -189,10 +188,7 @@ class BlackjackEnv(gym.Env):
|
|||||||
self.renderer.reset()
|
self.renderer.reset()
|
||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
|
|
||||||
if not return_info:
|
return self._get_obs(), {}
|
||||||
return self._get_obs()
|
|
||||||
else:
|
|
||||||
return self._get_obs(), {}
|
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return self.renderer.get_renders()
|
return self.renderer.get_renders()
|
||||||
|
@@ -149,22 +149,14 @@ class CliffWalkingEnv(Env):
|
|||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
return (int(s), r, t, False, {"prob": p})
|
return (int(s), r, t, False, {"prob": p})
|
||||||
|
|
||||||
def reset(
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
self,
|
|
||||||
*,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None
|
|
||||||
):
|
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
|
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
|
||||||
self.lastaction = None
|
self.lastaction = None
|
||||||
self.renderer.reset()
|
self.renderer.reset()
|
||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
if not return_info:
|
|
||||||
return int(self.s)
|
return int(self.s), {"prob": 1}
|
||||||
else:
|
|
||||||
return int(self.s), {"prob": 1}
|
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return self.renderer.get_renders()
|
return self.renderer.get_renders()
|
||||||
|
@@ -256,7 +256,6 @@ class FrozenLakeEnv(Env):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
@@ -266,10 +265,7 @@ class FrozenLakeEnv(Env):
|
|||||||
self.renderer.reset()
|
self.renderer.reset()
|
||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
|
|
||||||
if not return_info:
|
return int(self.s), {"prob": 1}
|
||||||
return int(self.s)
|
|
||||||
else:
|
|
||||||
return int(self.s), {"prob": 1}
|
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return self.renderer.get_renders()
|
return self.renderer.get_renders()
|
||||||
|
@@ -89,7 +89,7 @@ class TaxiEnv(Env):
|
|||||||
|
|
||||||
### Info
|
### Info
|
||||||
|
|
||||||
``step`` and ``reset(return_info=True)`` will return an info dictionary that contains "p" and "action_mask" containing
|
``step`` and ``reset()`` will return an info dictionary that contains "p" and "action_mask" containing
|
||||||
the probability that the state is taken and a mask of what actions will result in a change of state to speed up training.
|
the probability that the state is taken and a mask of what actions will result in a change of state to speed up training.
|
||||||
|
|
||||||
As Taxi's initial state is a stochastic, the "p" key represents the probability of the
|
As Taxi's initial state is a stochastic, the "p" key represents the probability of the
|
||||||
@@ -266,7 +266,6 @@ class TaxiEnv(Env):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
@@ -275,10 +274,8 @@ class TaxiEnv(Env):
|
|||||||
self.taxi_orientation = 0
|
self.taxi_orientation = 0
|
||||||
self.renderer.reset()
|
self.renderer.reset()
|
||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
if not return_info:
|
|
||||||
return int(self.s)
|
return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)}
|
||||||
else:
|
|
||||||
return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)}
|
|
||||||
|
|
||||||
def render(self):
|
def render(self):
|
||||||
return self.renderer.get_renders()
|
return self.renderer.get_renders()
|
||||||
|
@@ -73,7 +73,7 @@ def check_reset_seed(env: gym.Env):
|
|||||||
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
|
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
obs_1 = env.reset(seed=123)
|
obs_1, info = env.reset(seed=123)
|
||||||
assert (
|
assert (
|
||||||
obs_1 in env.observation_space
|
obs_1 in env.observation_space
|
||||||
), "The observation returned by `env.reset(seed=123)` is not within the observation space."
|
), "The observation returned by `env.reset(seed=123)` is not within the observation space."
|
||||||
@@ -85,7 +85,7 @@ def check_reset_seed(env: gym.Env):
|
|||||||
env.unwrapped._np_random # pyright: ignore [reportPrivateUsage]
|
env.unwrapped._np_random # pyright: ignore [reportPrivateUsage]
|
||||||
)
|
)
|
||||||
|
|
||||||
obs_2 = env.reset(seed=123)
|
obs_2, info = env.reset(seed=123)
|
||||||
assert (
|
assert (
|
||||||
obs_2 in env.observation_space
|
obs_2 in env.observation_space
|
||||||
), "The observation returned by `env.reset(seed=123)` is not within the observation space."
|
), "The observation returned by `env.reset(seed=123)` is not within the observation space."
|
||||||
@@ -98,7 +98,7 @@ def check_reset_seed(env: gym.Env):
|
|||||||
== seed_123_rng.bit_generator.state
|
== seed_123_rng.bit_generator.state
|
||||||
), "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`."
|
), "Mostly likely the environment reset function does not call `super().reset(seed=seed)` as the random generates are not same when the same seeds are passed to `env.reset`."
|
||||||
|
|
||||||
obs_3 = env.reset(seed=456)
|
obs_3, info = env.reset(seed=456)
|
||||||
assert (
|
assert (
|
||||||
obs_3 in env.observation_space
|
obs_3 in env.observation_space
|
||||||
), "The observation returned by `env.reset(seed=456)` is not within the observation space."
|
), "The observation returned by `env.reset(seed=456)` is not within the observation space."
|
||||||
@@ -126,53 +126,6 @@ def check_reset_seed(env: gym.Env):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_reset_info(env: gym.Env):
|
|
||||||
"""Checks that :meth:`reset` supports the ``return_info`` keyword.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
env: The environment to check
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
AssertionError: The environment cannot be reset with `return_info=True`,
|
|
||||||
even though `return_info` or `kwargs` appear in the signature.
|
|
||||||
"""
|
|
||||||
signature = inspect.signature(env.reset)
|
|
||||||
if "return_info" in signature.parameters or (
|
|
||||||
"kwargs" in signature.parameters
|
|
||||||
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
obs = env.reset(return_info=False)
|
|
||||||
assert (
|
|
||||||
obs in env.observation_space
|
|
||||||
), "The value returned by `env.reset(return_info=True)` is not within the observation space."
|
|
||||||
|
|
||||||
result = env.reset(return_info=True)
|
|
||||||
assert isinstance(
|
|
||||||
result, tuple
|
|
||||||
), f"Calling the reset method with `return_info=True` did not return a tuple, actual type: {type(result)}"
|
|
||||||
assert (
|
|
||||||
len(result) == 2
|
|
||||||
), f"Calling the reset method with `return_info=True` did not return a 2-tuple, actual length: {len(result)}"
|
|
||||||
|
|
||||||
obs, info = result
|
|
||||||
assert (
|
|
||||||
obs in env.observation_space
|
|
||||||
), "The first element returned by `env.reset(return_info=True)` is not within the observation space."
|
|
||||||
assert isinstance(
|
|
||||||
info, dict
|
|
||||||
), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: {type(info)}"
|
|
||||||
except TypeError as e:
|
|
||||||
raise AssertionError(
|
|
||||||
"The environment cannot be reset with `return_info=True`, even though `return_info` or `kwargs` appear in the signature. "
|
|
||||||
f"This should never happen, please report this issue. The error was: {e}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise gym.error.Error(
|
|
||||||
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def check_reset_options(env: gym.Env):
|
def check_reset_options(env: gym.Env):
|
||||||
"""Check that the environment can be reset with options.
|
"""Check that the environment can be reset with options.
|
||||||
|
|
||||||
@@ -201,6 +154,64 @@ def check_reset_options(env: gym.Env):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_reset_return_info_deprecation(env: gym.Env):
|
||||||
|
"""Makes sure support for deprecated `return_info` argument is dropped.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to check
|
||||||
|
Raises:
|
||||||
|
UserWarning
|
||||||
|
"""
|
||||||
|
signature = inspect.signature(env.reset)
|
||||||
|
if "return_info" in signature.parameters:
|
||||||
|
logger.warn(
|
||||||
|
"`return_info` is deprecated as an optional argument to `reset`. `reset`"
|
||||||
|
"should now always return `obs, info` where `obs` is an observation, and `info` is a dictionary"
|
||||||
|
"containing additional information."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_seed_deprecation(env: gym.Env):
|
||||||
|
"""Makes sure support for deprecated function `seed` is dropped.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to check
|
||||||
|
Raises:
|
||||||
|
UserWarning
|
||||||
|
"""
|
||||||
|
seed_fn = getattr(env, "seed", None)
|
||||||
|
if callable(seed_fn):
|
||||||
|
logger.warn(
|
||||||
|
"Official support for the `seed` function is dropped. "
|
||||||
|
"Standard practice is to reset gym environments using `env.reset(seed=<desired seed>)`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_reset_return_type(env: gym.Env):
|
||||||
|
"""Checks that :meth:`reset` correctly returns a tuple of the form `(obs , info)`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: The environment to check
|
||||||
|
Raises:
|
||||||
|
AssertionError depending on spec violation
|
||||||
|
"""
|
||||||
|
result = env.reset()
|
||||||
|
assert isinstance(
|
||||||
|
result, tuple
|
||||||
|
), f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(result)}`"
|
||||||
|
assert (
|
||||||
|
len(result) == 2
|
||||||
|
), f"Calling the reset method did not return a 2-tuple, actual length: {len(result)}"
|
||||||
|
|
||||||
|
obs, info = result
|
||||||
|
assert (
|
||||||
|
obs in env.observation_space
|
||||||
|
), "The first element returned by `env.reset()` is not within the observation space."
|
||||||
|
assert isinstance(
|
||||||
|
info, dict
|
||||||
|
), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}"
|
||||||
|
|
||||||
|
|
||||||
def check_space_limit(space, space_type: str):
|
def check_space_limit(space, space_type: str):
|
||||||
"""Check the space limit for only the Box space as a test that only runs as part of `check_env`."""
|
"""Check the space limit for only the Box space as a test that only runs as part of `check_env`."""
|
||||||
if isinstance(space, spaces.Box):
|
if isinstance(space, spaces.Box):
|
||||||
@@ -279,9 +290,11 @@ def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = False):
|
|||||||
check_space_limit(env.observation_space, "observation")
|
check_space_limit(env.observation_space, "observation")
|
||||||
|
|
||||||
# ==== Check the reset method ====
|
# ==== Check the reset method ====
|
||||||
|
check_seed_deprecation(env)
|
||||||
|
check_reset_return_info_deprecation(env)
|
||||||
|
check_reset_return_type(env)
|
||||||
check_reset_seed(env)
|
check_reset_seed(env)
|
||||||
check_reset_options(env)
|
check_reset_options(env)
|
||||||
check_reset_info(env)
|
|
||||||
|
|
||||||
# ============ Check the returned values ===============
|
# ============ Check the returned values ===============
|
||||||
env_reset_passive_checker(env)
|
env_reset_passive_checker(env)
|
||||||
|
@@ -183,14 +183,6 @@ def env_reset_passive_checker(env, **kwargs):
|
|||||||
f"Actual default: {seed_param}"
|
f"Actual default: {seed_param}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if "return_info" not in signature.parameters and not (
|
|
||||||
"kwargs" in signature.parameters
|
|
||||||
and signature.parameters["kwargs"].kind is inspect.Parameter.VAR_KEYWORD
|
|
||||||
):
|
|
||||||
logger.warn(
|
|
||||||
"Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting."
|
|
||||||
)
|
|
||||||
|
|
||||||
if "options" not in signature.parameters and "kwargs" not in signature.parameters:
|
if "options" not in signature.parameters and "kwargs" not in signature.parameters:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information."
|
"Future gym versions will require that `Env.reset` can be passed `options` to allow the environment initialisation to be passed additional information."
|
||||||
@@ -198,21 +190,17 @@ def env_reset_passive_checker(env, **kwargs):
|
|||||||
|
|
||||||
# Checks the result of env.reset with kwargs
|
# Checks the result of env.reset with kwargs
|
||||||
result = env.reset(**kwargs)
|
result = env.reset(**kwargs)
|
||||||
if kwargs.get("return_info", False) is True:
|
|
||||||
assert isinstance(
|
|
||||||
result, tuple
|
|
||||||
), f"The result returned by `env.reset(return_info=True)` was not a tuple, actual type: {type(result)}"
|
|
||||||
assert (
|
|
||||||
len(result) == 2
|
|
||||||
), f"The length of the result returned by `env.reset(return_info=True)` is not 2, actual length: {len(result)}"
|
|
||||||
obs, info = result
|
|
||||||
assert isinstance(
|
|
||||||
info, dict
|
|
||||||
), f"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: {type(info)}"
|
|
||||||
else:
|
|
||||||
obs = result
|
|
||||||
|
|
||||||
|
if not isinstance(result, tuple):
|
||||||
|
logger.warn(
|
||||||
|
f"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `{type(result)}`"
|
||||||
|
)
|
||||||
|
|
||||||
|
obs, info = result
|
||||||
check_obs(obs, env.observation_space, "reset")
|
check_obs(obs, env.observation_space, "reset")
|
||||||
|
assert isinstance(
|
||||||
|
info, dict
|
||||||
|
), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@@ -171,38 +171,9 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
self._state = AsyncState.DEFAULT
|
self._state = AsyncState.DEFAULT
|
||||||
self._check_spaces()
|
self._check_spaces()
|
||||||
|
|
||||||
def seed(self, seed=None):
|
|
||||||
"""Seeds the vector environments.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed: The seeds use with the environments
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
AlreadyPendingCallError: Calling `seed` while waiting for a pending call to complete
|
|
||||||
"""
|
|
||||||
super().seed(seed=seed)
|
|
||||||
self._assert_is_running()
|
|
||||||
if seed is None:
|
|
||||||
seed = [None for _ in range(self.num_envs)]
|
|
||||||
if isinstance(seed, int):
|
|
||||||
seed = [seed + i for i in range(self.num_envs)]
|
|
||||||
assert len(seed) == self.num_envs
|
|
||||||
|
|
||||||
if self._state != AsyncState.DEFAULT:
|
|
||||||
raise AlreadyPendingCallError(
|
|
||||||
f"Calling `seed` while waiting for a pending call to `{self._state.value}` to complete.",
|
|
||||||
self._state.value,
|
|
||||||
)
|
|
||||||
|
|
||||||
for pipe, seed in zip(self.parent_pipes, seed):
|
|
||||||
pipe.send(("seed", seed))
|
|
||||||
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
|
|
||||||
self._raise_if_errors(successes)
|
|
||||||
|
|
||||||
def reset_async(
|
def reset_async(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
"""Send calls to the :obj:`reset` methods of the sub-environments.
|
"""Send calls to the :obj:`reset` methods of the sub-environments.
|
||||||
@@ -211,7 +182,6 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
seed: List of seeds for each environment
|
seed: List of seeds for each environment
|
||||||
return_info: If to return information
|
|
||||||
options: The reset option
|
options: The reset option
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@@ -238,8 +208,6 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
single_kwargs = {}
|
single_kwargs = {}
|
||||||
if single_seed is not None:
|
if single_seed is not None:
|
||||||
single_kwargs["seed"] = single_seed
|
single_kwargs["seed"] = single_seed
|
||||||
if return_info:
|
|
||||||
single_kwargs["return_info"] = return_info
|
|
||||||
if options is not None:
|
if options is not None:
|
||||||
single_kwargs["options"] = options
|
single_kwargs["options"] = options
|
||||||
|
|
||||||
@@ -250,7 +218,6 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
self,
|
self,
|
||||||
timeout: Optional[Union[int, float]] = None,
|
timeout: Optional[Union[int, float]] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
) -> Union[ObsType, Tuple[ObsType, List[dict]]]:
|
) -> Union[ObsType, Tuple[ObsType, List[dict]]]:
|
||||||
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
|
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
|
||||||
@@ -258,7 +225,6 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
Args:
|
Args:
|
||||||
timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out.
|
timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out.
|
||||||
seed: ignored
|
seed: ignored
|
||||||
return_info: If to return information
|
|
||||||
options: ignored
|
options: ignored
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -286,27 +252,17 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
self._raise_if_errors(successes)
|
self._raise_if_errors(successes)
|
||||||
self._state = AsyncState.DEFAULT
|
self._state = AsyncState.DEFAULT
|
||||||
|
|
||||||
if return_info:
|
infos = {}
|
||||||
infos = {}
|
results, info_data = zip(*results)
|
||||||
results, info_data = zip(*results)
|
for i, info in enumerate(info_data):
|
||||||
for i, info in enumerate(info_data):
|
infos = self._add_info(infos, info, i)
|
||||||
infos = self._add_info(infos, info, i)
|
|
||||||
|
|
||||||
if not self.shared_memory:
|
if not self.shared_memory:
|
||||||
self.observations = concatenate(
|
self.observations = concatenate(
|
||||||
self.single_observation_space, results, self.observations
|
self.single_observation_space, results, self.observations
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (deepcopy(self.observations) if self.copy else self.observations), infos
|
||||||
deepcopy(self.observations) if self.copy else self.observations
|
|
||||||
), infos
|
|
||||||
else:
|
|
||||||
if not self.shared_memory:
|
|
||||||
self.observations = concatenate(
|
|
||||||
self.single_observation_space, results, self.observations
|
|
||||||
)
|
|
||||||
|
|
||||||
return deepcopy(self.observations) if self.copy else self.observations
|
|
||||||
|
|
||||||
def step_async(self, actions: np.ndarray):
|
def step_async(self, actions: np.ndarray):
|
||||||
"""Send the calls to :obj:`step` to each sub-environment.
|
"""Send the calls to :obj:`step` to each sub-environment.
|
||||||
@@ -606,12 +562,8 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
|||||||
while True:
|
while True:
|
||||||
command, data = pipe.recv()
|
command, data = pipe.recv()
|
||||||
if command == "reset":
|
if command == "reset":
|
||||||
if "return_info" in data and data["return_info"] is True:
|
observation, info = env.reset(**data)
|
||||||
observation, info = env.reset(**data)
|
pipe.send(((observation, info), True))
|
||||||
pipe.send(((observation, info), True))
|
|
||||||
else:
|
|
||||||
observation = env.reset(**data)
|
|
||||||
pipe.send((observation, True))
|
|
||||||
|
|
||||||
elif command == "step":
|
elif command == "step":
|
||||||
(
|
(
|
||||||
@@ -622,8 +574,9 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
|||||||
info,
|
info,
|
||||||
) = step_api_compatibility(env.step(data), True)
|
) = step_api_compatibility(env.step(data), True)
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
info["final_observation"] = observation
|
old_observation = observation
|
||||||
observation = env.reset()
|
observation, info = env.reset()
|
||||||
|
info["final_observation"] = old_observation
|
||||||
pipe.send(((observation, reward, terminated, truncated, info), True))
|
pipe.send(((observation, reward, terminated, truncated, info), True))
|
||||||
elif command == "seed":
|
elif command == "seed":
|
||||||
env.seed(data)
|
env.seed(data)
|
||||||
@@ -676,18 +629,12 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
|
|||||||
while True:
|
while True:
|
||||||
command, data = pipe.recv()
|
command, data = pipe.recv()
|
||||||
if command == "reset":
|
if command == "reset":
|
||||||
if "return_info" in data and data["return_info"] is True:
|
observation, info = env.reset(**data)
|
||||||
observation, info = env.reset(**data)
|
write_to_shared_memory(
|
||||||
write_to_shared_memory(
|
observation_space, index, observation, shared_memory
|
||||||
observation_space, index, observation, shared_memory
|
)
|
||||||
)
|
pipe.send(((None, info), True))
|
||||||
pipe.send(((None, info), True))
|
|
||||||
else:
|
|
||||||
observation = env.reset(**data)
|
|
||||||
write_to_shared_memory(
|
|
||||||
observation_space, index, observation, shared_memory
|
|
||||||
)
|
|
||||||
pipe.send((None, True))
|
|
||||||
elif command == "step":
|
elif command == "step":
|
||||||
(
|
(
|
||||||
observation,
|
observation,
|
||||||
@@ -697,8 +644,10 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
|
|||||||
info,
|
info,
|
||||||
) = step_api_compatibility(env.step(data), True)
|
) = step_api_compatibility(env.step(data), True)
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
info["final_observation"] = observation
|
old_observation = observation
|
||||||
observation = env.reset()
|
observation, info = env.reset()
|
||||||
|
info["final_observation"] = old_observation
|
||||||
|
|
||||||
write_to_shared_memory(
|
write_to_shared_memory(
|
||||||
observation_space, index, observation, shared_memory
|
observation_space, index, observation, shared_memory
|
||||||
)
|
)
|
||||||
|
@@ -93,14 +93,12 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
def reset_wait(
|
def reset_wait(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
|
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
seed: The reset environment seed
|
seed: The reset environment seed
|
||||||
return_info: If to return information
|
|
||||||
options: Option information for the environment reset
|
options: Option information for the environment reset
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -123,26 +121,15 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
kwargs["seed"] = single_seed
|
kwargs["seed"] = single_seed
|
||||||
if options is not None:
|
if options is not None:
|
||||||
kwargs["options"] = options
|
kwargs["options"] = options
|
||||||
if return_info is True:
|
|
||||||
kwargs["return_info"] = return_info
|
|
||||||
|
|
||||||
if not return_info:
|
observation, info = env.reset(**kwargs)
|
||||||
observation = env.reset(**kwargs)
|
observations.append(observation)
|
||||||
observations.append(observation)
|
infos = self._add_info(infos, info, i)
|
||||||
else:
|
|
||||||
observation, info = env.reset(**kwargs)
|
|
||||||
observations.append(observation)
|
|
||||||
infos = self._add_info(infos, info, i)
|
|
||||||
|
|
||||||
self.observations = concatenate(
|
self.observations = concatenate(
|
||||||
self.single_observation_space, observations, self.observations
|
self.single_observation_space, observations, self.observations
|
||||||
)
|
)
|
||||||
if not return_info:
|
return (deepcopy(self.observations) if self.copy else self.observations), infos
|
||||||
return deepcopy(self.observations) if self.copy else self.observations
|
|
||||||
else:
|
|
||||||
return (
|
|
||||||
deepcopy(self.observations) if self.copy else self.observations
|
|
||||||
), infos
|
|
||||||
|
|
||||||
def step_async(self, actions):
|
def step_async(self, actions):
|
||||||
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
|
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
|
||||||
@@ -164,8 +151,9 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
info,
|
info,
|
||||||
) = step_api_compatibility(env.step(action), True)
|
) = step_api_compatibility(env.step(action), True)
|
||||||
if self._terminateds[i] or self._truncateds[i]:
|
if self._terminateds[i] or self._truncateds[i]:
|
||||||
info["final_observation"] = observation
|
old_observation = observation
|
||||||
observation = env.reset()
|
observation, info = env.reset()
|
||||||
|
info["final_observation"] = old_observation
|
||||||
observations.append(observation)
|
observations.append(observation)
|
||||||
infos = self._add_info(infos, info, i)
|
infos = self._add_info(infos, info, i)
|
||||||
self.observations = concatenate(
|
self.observations = concatenate(
|
||||||
|
@@ -60,7 +60,6 @@ class VectorEnv(gym.Env):
|
|||||||
def reset_async(
|
def reset_async(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
"""Reset the sub-environments asynchronously.
|
"""Reset the sub-environments asynchronously.
|
||||||
@@ -70,7 +69,6 @@ class VectorEnv(gym.Env):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
seed: The reset seed
|
seed: The reset seed
|
||||||
return_info: If to return info
|
|
||||||
options: Reset options
|
options: Reset options
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
@@ -78,7 +76,6 @@ class VectorEnv(gym.Env):
|
|||||||
def reset_wait(
|
def reset_wait(
|
||||||
self,
|
self,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
"""Retrieves the results of a :meth:`reset_async` call.
|
"""Retrieves the results of a :meth:`reset_async` call.
|
||||||
@@ -87,7 +84,6 @@ class VectorEnv(gym.Env):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
seed: The reset seed
|
seed: The reset seed
|
||||||
return_info: If to return info
|
|
||||||
options: Reset options
|
options: Reset options
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -102,21 +98,19 @@ class VectorEnv(gym.Env):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[Union[int, List[int]]] = None,
|
seed: Optional[Union[int, List[int]]] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
"""Reset all parallel environments and return a batch of initial observations.
|
"""Reset all parallel environments and return a batch of initial observations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
seed: The environment reset seeds
|
seed: The environment reset seeds
|
||||||
return_info: If to return the info
|
|
||||||
options: If to return the options
|
options: If to return the options
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A batch of observations from the vectorized environment.
|
A batch of observations from the vectorized environment.
|
||||||
"""
|
"""
|
||||||
self.reset_async(seed=seed, return_info=return_info, options=options)
|
self.reset_async(seed=seed, options=options)
|
||||||
return self.reset_wait(seed=seed, return_info=return_info, options=options)
|
return self.reset_wait(seed=seed, options=options)
|
||||||
|
|
||||||
def step_async(self, actions):
|
def step_async(self, actions):
|
||||||
"""Asynchronously performs steps in the sub-environments.
|
"""Asynchronously performs steps in the sub-environments.
|
||||||
@@ -220,21 +214,6 @@ class VectorEnv(gym.Env):
|
|||||||
self.close_extras(**kwargs)
|
self.close_extras(**kwargs)
|
||||||
self.closed = True
|
self.closed = True
|
||||||
|
|
||||||
def seed(self, seed=None):
|
|
||||||
"""Set the random seed in all parallel environments.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed: Random seed for each parallel environment. If ``seed`` is a list of
|
|
||||||
length ``num_envs``, then the items of the list are chosen as random
|
|
||||||
seeds. If ``seed`` is an int, then each parallel environment uses the random
|
|
||||||
seed ``seed + n``, where ``n`` is the index of the parallel environment
|
|
||||||
(between ``0`` and ``num_envs - 1``).
|
|
||||||
"""
|
|
||||||
deprecation(
|
|
||||||
"Function `env.seed(seed)` is marked as deprecated and will be removed in the future. "
|
|
||||||
"Please use `env.reset(seed=seed) instead in VectorEnvs."
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
|
def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
|
||||||
"""Add env info to the info dictionary of the vectorized environment.
|
"""Add env info to the info dictionary of the vectorized environment.
|
||||||
|
|
||||||
@@ -339,9 +318,6 @@ class VectorEnvWrapper(VectorEnv):
|
|||||||
def close_extras(self, **kwargs):
|
def close_extras(self, **kwargs):
|
||||||
return self.env.close_extras(**kwargs)
|
return self.env.close_extras(**kwargs)
|
||||||
|
|
||||||
def seed(self, seed=None):
|
|
||||||
return self.env.seed(seed)
|
|
||||||
|
|
||||||
def call(self, name, *args, **kwargs):
|
def call(self, name, *args, **kwargs):
|
||||||
return self.env.call(name, *args, **kwargs)
|
return self.env.call(name, *args, **kwargs)
|
||||||
|
|
||||||
|
@@ -151,11 +151,7 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
"""Resets the environment using preprocessing."""
|
"""Resets the environment using preprocessing."""
|
||||||
# NoopReset
|
# NoopReset
|
||||||
if kwargs.get("return_info", False):
|
_, reset_info = self.env.reset(**kwargs)
|
||||||
_, reset_info = self.env.reset(**kwargs)
|
|
||||||
else:
|
|
||||||
_ = self.env.reset(**kwargs)
|
|
||||||
reset_info = {}
|
|
||||||
|
|
||||||
noops = (
|
noops = (
|
||||||
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
|
self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
|
||||||
@@ -168,11 +164,7 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
)
|
)
|
||||||
reset_info.update(step_info)
|
reset_info.update(step_info)
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
if kwargs.get("return_info", False):
|
_, reset_info = self.env.reset(**kwargs)
|
||||||
_, reset_info = self.env.reset(**kwargs)
|
|
||||||
else:
|
|
||||||
_ = self.env.reset(**kwargs)
|
|
||||||
reset_info = {}
|
|
||||||
|
|
||||||
self.lives = self.ale.lives()
|
self.lives = self.ale.lives()
|
||||||
if self.grayscale_obs:
|
if self.grayscale_obs:
|
||||||
@@ -181,10 +173,7 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
self.ale.getScreenRGB(self.obs_buffer[0])
|
self.ale.getScreenRGB(self.obs_buffer[0])
|
||||||
self.obs_buffer[1].fill(0)
|
self.obs_buffer[1].fill(0)
|
||||||
|
|
||||||
if kwargs.get("return_info", False):
|
return self._get_obs(), reset_info
|
||||||
return self._get_obs(), reset_info
|
|
||||||
else:
|
|
||||||
return self._get_obs()
|
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
if self.frame_skip > 1: # more efficient in-place pooling
|
if self.frame_skip > 1: # more efficient in-place pooling
|
||||||
|
@@ -48,7 +48,7 @@ class AutoResetWrapper(gym.Wrapper):
|
|||||||
|
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
|
|
||||||
new_obs, new_info = self.env.reset(return_info=True)
|
new_obs, new_info = self.env.reset()
|
||||||
assert (
|
assert (
|
||||||
"final_observation" not in new_info
|
"final_observation" not in new_info
|
||||||
), 'info dict cannot contain key "final_observation" '
|
), 'info dict cannot contain key "final_observation" '
|
||||||
|
@@ -191,14 +191,8 @@ class FrameStack(gym.ObservationWrapper):
|
|||||||
Returns:
|
Returns:
|
||||||
The stacked observations
|
The stacked observations
|
||||||
"""
|
"""
|
||||||
if kwargs.get("return_info", False):
|
obs, info = self.env.reset(**kwargs)
|
||||||
obs, info = self.env.reset(**kwargs)
|
|
||||||
else:
|
|
||||||
obs = self.env.reset(**kwargs)
|
|
||||||
info = None # Unused
|
|
||||||
[self.frames.append(obs) for _ in range(self.num_stack)]
|
[self.frames.append(obs) for _ in range(self.num_stack)]
|
||||||
|
|
||||||
if kwargs.get("return_info", False):
|
return self.observation(None), info
|
||||||
return self.observation(None), info
|
|
||||||
else:
|
|
||||||
return self.observation(None)
|
|
||||||
|
@@ -89,20 +89,12 @@ class NormalizeObservation(gym.core.Wrapper):
|
|||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
"""Resets the environment and normalizes the observation."""
|
"""Resets the environment and normalizes the observation."""
|
||||||
if kwargs.get("return_info", False):
|
obs, info = self.env.reset(**kwargs)
|
||||||
obs, info = self.env.reset(**kwargs)
|
|
||||||
|
|
||||||
if self.is_vector_env:
|
if self.is_vector_env:
|
||||||
return self.normalize(obs), info
|
return self.normalize(obs), info
|
||||||
else:
|
|
||||||
return self.normalize(np.array([obs]))[0], info
|
|
||||||
else:
|
else:
|
||||||
obs = self.env.reset(**kwargs)
|
return self.normalize(np.array([obs]))[0], info
|
||||||
|
|
||||||
if self.is_vector_env:
|
|
||||||
return self.normalize(obs)
|
|
||||||
else:
|
|
||||||
return self.normalize(np.array([obs]))[0]
|
|
||||||
|
|
||||||
def normalize(self, obs):
|
def normalize(self, obs):
|
||||||
"""Normalises the observation using the running mean and variance of the observations."""
|
"""Normalises the observation using the running mean and variance of the observations."""
|
||||||
|
@@ -57,9 +57,6 @@ class VectorListInfo(gym.Wrapper):
|
|||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
"""Resets the environment using kwargs."""
|
"""Resets the environment using kwargs."""
|
||||||
if not kwargs.get("return_info"):
|
|
||||||
return self.env.reset(**kwargs)
|
|
||||||
|
|
||||||
obs, infos = self.env.reset(**kwargs)
|
obs, infos = self.env.reset(**kwargs)
|
||||||
list_info = self._convert_info_to_list(infos)
|
list_info = self._convert_info_to_list(infos)
|
||||||
return obs, list_info
|
return obs, list_info
|
||||||
|
@@ -139,7 +139,7 @@ def test_taxi_action_mask():
|
|||||||
def test_taxi_encode_decode():
|
def test_taxi_encode_decode():
|
||||||
env = TaxiEnv()
|
env = TaxiEnv()
|
||||||
|
|
||||||
state = env.reset()
|
state, info = env.reset()
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
assert (
|
assert (
|
||||||
env.encode(*env.decode(state)) == state
|
env.encode(*env.decode(state)) == state
|
||||||
|
@@ -82,8 +82,8 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
|
|||||||
env_1 = env_spec.make(disable_env_checker=True)
|
env_1 = env_spec.make(disable_env_checker=True)
|
||||||
env_2 = env_spec.make(disable_env_checker=True)
|
env_2 = env_spec.make(disable_env_checker=True)
|
||||||
|
|
||||||
initial_obs_1 = env_1.reset(seed=SEED)
|
initial_obs_1, initial_info_1 = env_1.reset(seed=SEED)
|
||||||
initial_obs_2 = env_2.reset(seed=SEED)
|
initial_obs_2, initial_info_2 = env_2.reset(seed=SEED)
|
||||||
assert_equals(initial_obs_1, initial_obs_2)
|
assert_equals(initial_obs_1, initial_obs_2)
|
||||||
|
|
||||||
env_1.action_space.seed(SEED)
|
env_1.action_space.seed(SEED)
|
||||||
|
@@ -17,8 +17,8 @@ def verify_environments_match(
|
|||||||
old_env = envs.make(old_env_id, disable_env_checker=True)
|
old_env = envs.make(old_env_id, disable_env_checker=True)
|
||||||
new_env = envs.make(new_env_id, disable_env_checker=True)
|
new_env = envs.make(new_env_id, disable_env_checker=True)
|
||||||
|
|
||||||
old_reset_obs = old_env.reset(seed=seed)
|
old_reset_obs, old_info = old_env.reset(seed=seed)
|
||||||
new_reset_obs = new_env.reset(seed=seed)
|
new_reset_obs, new_info = new_env.reset(seed=seed)
|
||||||
|
|
||||||
np.testing.assert_allclose(old_reset_obs, new_reset_obs)
|
np.testing.assert_allclose(old_reset_obs, new_reset_obs)
|
||||||
|
|
||||||
@@ -56,7 +56,7 @@ EXCLUDE_POS_FROM_OBS = [
|
|||||||
def test_obs_space_mujoco_environments(env_spec: EnvSpec):
|
def test_obs_space_mujoco_environments(env_spec: EnvSpec):
|
||||||
"""Check that the returned observations are contained in the observation space of the environment"""
|
"""Check that the returned observations are contained in the observation space of the environment"""
|
||||||
env = env_spec.make(disable_env_checker=True)
|
env = env_spec.make(disable_env_checker=True)
|
||||||
reset_obs = env.reset()
|
reset_obs, info = env.reset()
|
||||||
assert env.observation_space.contains(
|
assert env.observation_space.contains(
|
||||||
reset_obs
|
reset_obs
|
||||||
), f"Obseravtion returned by reset() of {env_spec.id} is not contained in the default observation space {env.observation_space}."
|
), f"Obseravtion returned by reset() of {env_spec.id} is not contained in the default observation space {env.observation_space}."
|
||||||
@@ -73,7 +73,7 @@ def test_obs_space_mujoco_environments(env_spec: EnvSpec):
|
|||||||
env = env_spec.make(
|
env = env_spec.make(
|
||||||
disable_env_checker=True, exclude_current_positions_from_observation=False
|
disable_env_checker=True, exclude_current_positions_from_observation=False
|
||||||
)
|
)
|
||||||
reset_obs = env.reset()
|
reset_obs, info = env.reset()
|
||||||
assert env.observation_space.contains(
|
assert env.observation_space.contains(
|
||||||
reset_obs
|
reset_obs
|
||||||
), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when excluding current position from observation."
|
), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when excluding current position from observation."
|
||||||
@@ -86,7 +86,7 @@ def test_obs_space_mujoco_environments(env_spec: EnvSpec):
|
|||||||
# Ant-v4 has the option of including contact forces in the observation space with the use_contact_forces argument
|
# Ant-v4 has the option of including contact forces in the observation space with the use_contact_forces argument
|
||||||
if env_spec.name == "Ant" and env_spec.version == 4:
|
if env_spec.name == "Ant" and env_spec.version == 4:
|
||||||
env = env_spec.make(disable_env_checker=True, use_contact_forces=True)
|
env = env_spec.make(disable_env_checker=True, use_contact_forces=True)
|
||||||
reset_obs = env.reset()
|
reset_obs, info = env.reset()
|
||||||
assert env.observation_space.contains(
|
assert env.observation_space.contains(
|
||||||
reset_obs
|
reset_obs
|
||||||
), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when using contact forces."
|
), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when using contact forces."
|
||||||
|
@@ -21,17 +21,9 @@ class UnittestEnv(core.Env):
|
|||||||
observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
|
observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
|
||||||
action_space = spaces.Discrete(3)
|
action_space = spaces.Discrete(3)
|
||||||
|
|
||||||
def reset(
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
self,
|
|
||||||
*,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None
|
|
||||||
):
|
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
if return_info:
|
return self.observation_space.sample(), {"info": "dummy"}
|
||||||
return self.observation_space.sample(), {"info": "dummy"}
|
|
||||||
return self.observation_space.sample() # Dummy observation
|
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
observation = self.observation_space.sample() # Dummy observation
|
observation = self.observation_space.sample() # Dummy observation
|
||||||
@@ -45,22 +37,13 @@ class UnknownSpacesEnv(core.Env):
|
|||||||
on external resources), it is not encouraged.
|
on external resources), it is not encouraged.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def reset(
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
self,
|
|
||||||
*,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None
|
|
||||||
):
|
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(
|
||||||
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
|
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
|
||||||
)
|
)
|
||||||
self.action_space = spaces.Discrete(3)
|
self.action_space = spaces.Discrete(3)
|
||||||
if not return_info:
|
return self.observation_space.sample(), {} # Dummy observation with info
|
||||||
return self.observation_space.sample() # Dummy observation
|
|
||||||
else:
|
|
||||||
return self.observation_space.sample(), {} # Dummy observation with info
|
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
observation = self.observation_space.sample() # Dummy observation
|
observation = self.observation_space.sample() # Dummy observation
|
||||||
|
@@ -12,16 +12,12 @@ def basic_reset_fn(
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
) -> Union[ObsType, Tuple[ObsType, dict]]:
|
) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||||
"""A basic reset function that will pass the environment check using random actions from the observation space."""
|
"""A basic reset function that will pass the environment check using random actions from the observation space."""
|
||||||
super(GenericTestEnv, self).reset(seed=seed)
|
super(GenericTestEnv, self).reset(seed=seed)
|
||||||
self.observation_space.seed(seed)
|
self.observation_space.seed(seed)
|
||||||
if return_info:
|
return self.observation_space.sample(), {"options": options}
|
||||||
return self.observation_space.sample(), {"options": options}
|
|
||||||
else:
|
|
||||||
return self.observation_space.sample()
|
|
||||||
|
|
||||||
|
|
||||||
def new_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
|
def new_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
|
||||||
@@ -77,7 +73,6 @@ class GenericTestEnv(gym.Env):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None,
|
options: Optional[dict] = None,
|
||||||
) -> Union[ObsType, Tuple[ObsType, dict]]:
|
) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||||
# If you need a default working reset function, use `basic_reset_fn` above
|
# If you need a default working reset function, use `basic_reset_fn` above
|
||||||
|
@@ -1,16 +1,21 @@
|
|||||||
"""Tests that the `env_checker` runs as expects and all errors are possible."""
|
"""Tests that the `env_checker` runs as expects and all errors are possible."""
|
||||||
import re
|
import re
|
||||||
|
import warnings
|
||||||
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
|
from gym.core import ObsType
|
||||||
from gym.utils.env_checker import (
|
from gym.utils.env_checker import (
|
||||||
check_env,
|
check_env,
|
||||||
check_reset_info,
|
|
||||||
check_reset_options,
|
check_reset_options,
|
||||||
|
check_reset_return_info_deprecation,
|
||||||
|
check_reset_return_type,
|
||||||
check_reset_seed,
|
check_reset_seed,
|
||||||
|
check_seed_deprecation,
|
||||||
)
|
)
|
||||||
from tests.testing_env import GenericTestEnv
|
from tests.testing_env import GenericTestEnv
|
||||||
|
|
||||||
@@ -48,29 +53,27 @@ def test_no_error_warnings(env):
|
|||||||
assert len(warnings) == 0, [warning.message for warning in warnings]
|
assert len(warnings) == 0, [warning.message for warning in warnings]
|
||||||
|
|
||||||
|
|
||||||
def _no_super_reset(self, seed=None, return_info=False, options=None):
|
def _no_super_reset(self, seed=None, options=None):
|
||||||
self.np_random.random() # generates a new prng
|
self.np_random.random() # generates a new prng
|
||||||
# generate seed deterministic result
|
# generate seed deterministic result
|
||||||
self.observation_space.seed(0)
|
self.observation_space.seed(0)
|
||||||
return self.observation_space.sample()
|
return self.observation_space.sample(), {}
|
||||||
|
|
||||||
|
|
||||||
def _super_reset_fixed(self, seed=None, return_info=False, options=None):
|
def _super_reset_fixed(self, seed=None, options=None):
|
||||||
# Call super that ignores the seed passed, use fixed seed
|
# Call super that ignores the seed passed, use fixed seed
|
||||||
super(GenericTestEnv, self).reset(seed=1)
|
super(GenericTestEnv, self).reset(seed=1)
|
||||||
# deterministic output
|
# deterministic output
|
||||||
self.observation_space._np_random = self.np_random
|
self.observation_space._np_random = self.np_random
|
||||||
return self.observation_space.sample()
|
return self.observation_space.sample(), {}
|
||||||
|
|
||||||
|
|
||||||
def _reset_default_seed(
|
def _reset_default_seed(self: GenericTestEnv, seed="Error", options=None):
|
||||||
self: GenericTestEnv, seed="Error", return_info=False, options=None
|
|
||||||
):
|
|
||||||
super(GenericTestEnv, self).reset(seed=seed)
|
super(GenericTestEnv, self).reset(seed=seed)
|
||||||
self.observation_space._np_random = ( # pyright: ignore [reportPrivateUsage]
|
self.observation_space._np_random = ( # pyright: ignore [reportPrivateUsage]
|
||||||
self.np_random
|
self.np_random
|
||||||
)
|
)
|
||||||
return self.observation_space.sample()
|
return self.observation_space.sample(), {}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -78,12 +81,12 @@ def _reset_default_seed(
|
|||||||
[
|
[
|
||||||
[
|
[
|
||||||
gym.error.Error,
|
gym.error.Error,
|
||||||
lambda self: self.observation_space.sample(),
|
lambda self: (self.observation_space.sample(), {}),
|
||||||
"The `reset` method does not provide a `seed` or `**kwargs` keyword argument.",
|
"The `reset` method does not provide a `seed` or `**kwargs` keyword argument.",
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
AssertionError,
|
AssertionError,
|
||||||
lambda self, seed, *_: self.observation_space.sample(),
|
lambda self, seed, *_: (self.observation_space.sample(), {}),
|
||||||
"Expects the random number generator to have been generated given a seed was passed to reset. Mostly likely the environment reset function does not call `super().reset(seed=seed)`.",
|
"Expects the random number generator to have been generated given a seed was passed to reset. Mostly likely the environment reset function does not call `super().reset(seed=seed)`.",
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
@@ -115,95 +118,125 @@ def test_check_reset_seed(test, func: callable, message: str):
|
|||||||
check_reset_seed(GenericTestEnv(reset_fn=func))
|
check_reset_seed(GenericTestEnv(reset_fn=func))
|
||||||
|
|
||||||
|
|
||||||
|
def _deprecated_return_info(
|
||||||
|
self, return_info: bool = False
|
||||||
|
) -> Union[Tuple[ObsType, dict], ObsType]:
|
||||||
|
"""function to simulate the signature and behavior of a `reset` function with the deprecated `return_info` optional argument"""
|
||||||
|
if return_info:
|
||||||
|
return self.observation_space.sample(), {}
|
||||||
|
else:
|
||||||
|
return self.observation_space.sample()
|
||||||
|
|
||||||
|
|
||||||
def _reset_var_keyword_kwargs(self, kwargs):
|
def _reset_var_keyword_kwargs(self, kwargs):
|
||||||
return self.observation_space.sample()
|
return self.observation_space.sample(), {}
|
||||||
|
|
||||||
|
|
||||||
def _reset_return_info_type(self, seed=None, return_info=False, options=None):
|
def _reset_return_info_type(self, seed=None, options=None):
|
||||||
if return_info:
|
"""Returns a `list` instead of a `tuple`. This function is used to make sure `env_checker` correctly
|
||||||
return [1, 2]
|
checks that the return type of `env.reset()` is a `tuple`"""
|
||||||
else:
|
return [self.observation_space.sample(), {}]
|
||||||
return self.observation_space.sample()
|
|
||||||
|
|
||||||
|
|
||||||
def _reset_return_info_length(self, seed=None, return_info=False, options=None):
|
def _reset_return_info_length(self, seed=None, options=None):
|
||||||
if return_info:
|
return 1, 2, 3
|
||||||
return 1, 2, 3
|
|
||||||
else:
|
|
||||||
return self.observation_space.sample()
|
|
||||||
|
|
||||||
|
|
||||||
def _return_info_obs_outside(self, seed=None, return_info=False, options=None):
|
def _return_info_obs_outside(self, seed=None, options=None):
|
||||||
if return_info:
|
return self.observation_space.sample() + self.observation_space.high, {}
|
||||||
return self.observation_space.sample() + self.observation_space.high, {}
|
|
||||||
else:
|
|
||||||
return self.observation_space.sample()
|
|
||||||
|
|
||||||
|
|
||||||
def _return_info_not_dict(self, seed=None, return_info=False, options=None):
|
def _return_info_not_dict(self, seed=None, options=None):
|
||||||
if return_info:
|
return self.observation_space.sample(), ["key", "value"]
|
||||||
return self.observation_space.sample(), ["key", "value"]
|
|
||||||
else:
|
|
||||||
return self.observation_space.sample()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"test,func,message",
|
"test,func,message",
|
||||||
[
|
[
|
||||||
[
|
|
||||||
gym.error.Error,
|
|
||||||
lambda self, *_: self.observation_space.sample(),
|
|
||||||
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument.",
|
|
||||||
],
|
|
||||||
[
|
|
||||||
gym.error.Error,
|
|
||||||
_reset_var_keyword_kwargs,
|
|
||||||
"The `reset` method does not provide a `return_info` or `**kwargs` keyword argument.",
|
|
||||||
],
|
|
||||||
[
|
[
|
||||||
AssertionError,
|
AssertionError,
|
||||||
_reset_return_info_type,
|
_reset_return_info_type,
|
||||||
"Calling the reset method with `return_info=True` did not return a tuple, actual type: <class 'list'>",
|
"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `<class 'list'>`",
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
AssertionError,
|
AssertionError,
|
||||||
_reset_return_info_length,
|
_reset_return_info_length,
|
||||||
"Calling the reset method with `return_info=True` did not return a 2-tuple, actual length: 3",
|
"Calling the reset method did not return a 2-tuple, actual length: 3",
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
AssertionError,
|
AssertionError,
|
||||||
_return_info_obs_outside,
|
_return_info_obs_outside,
|
||||||
"The first element returned by `env.reset(return_info=True)` is not within the observation space.",
|
"The first element returned by `env.reset()` is not within the observation space.",
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
AssertionError,
|
AssertionError,
|
||||||
_return_info_not_dict,
|
_return_info_not_dict,
|
||||||
"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: <class 'list'>",
|
"The second element returned by `env.reset()` was not a dictionary, actual type: <class 'list'>",
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_check_reset_info(test, func: callable, message: str):
|
def test_check_reset_return_type(test, func: callable, message: str):
|
||||||
"""Tests the check reset info function works as expected."""
|
"""Tests the check `env.reset()` function has a correct return type."""
|
||||||
if test is UserWarning:
|
|
||||||
with pytest.warns(
|
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||||
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
check_reset_return_type(GenericTestEnv(reset_fn=func))
|
||||||
):
|
|
||||||
check_reset_info(GenericTestEnv(reset_fn=func))
|
|
||||||
else:
|
@pytest.mark.parametrize(
|
||||||
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
"test,func,message",
|
||||||
check_reset_info(GenericTestEnv(reset_fn=func))
|
[
|
||||||
|
[
|
||||||
|
UserWarning,
|
||||||
|
_deprecated_return_info,
|
||||||
|
"`return_info` is deprecated as an optional argument to `reset`. `reset`"
|
||||||
|
"should now always return `obs, info` where `obs` is an observation, and `info` is a dictionary"
|
||||||
|
"containing additional information.",
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_check_reset_return_info_deprecation(test, func: callable, message: str):
|
||||||
|
"""Tests that return_info has been correct deprecated as an argument to `env.reset()`."""
|
||||||
|
|
||||||
|
with pytest.warns(test, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"):
|
||||||
|
check_reset_return_info_deprecation(GenericTestEnv(reset_fn=func))
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_seed_deprecation():
|
||||||
|
"""Tests that `check_seed_deprecation()` throws a warning if `env.seed()` has not been removed."""
|
||||||
|
|
||||||
|
message = """Official support for the `seed` function is dropped. Standard practice is to reset gym environments using `env.reset(seed=<desired seed>)`"""
|
||||||
|
|
||||||
|
env = GenericTestEnv()
|
||||||
|
|
||||||
|
def seed(seed):
|
||||||
|
return
|
||||||
|
|
||||||
|
with pytest.warns(
|
||||||
|
UserWarning, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"
|
||||||
|
):
|
||||||
|
env.seed = seed
|
||||||
|
assert callable(env.seed)
|
||||||
|
check_seed_deprecation(env)
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||||
|
env.seed = []
|
||||||
|
check_seed_deprecation(env)
|
||||||
|
env.seed = 123
|
||||||
|
check_seed_deprecation(env)
|
||||||
|
del env.seed
|
||||||
|
check_seed_deprecation(env)
|
||||||
|
assert len(caught_warnings) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_check_reset_options():
|
def test_check_reset_options():
|
||||||
"""Tests the check_reset_options function."""
|
"""Tests the check_reset_options function."""
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
gym.error.Error,
|
gym.error.Error,
|
||||||
match=re.escape(
|
match=re.escape(
|
||||||
"The `reset` method does not provide an `options` or `**kwargs` keyword argument"
|
"The `reset` method does not provide an `options` or `**kwargs` keyword argument"
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
check_reset_options(GenericTestEnv(reset_fn=lambda self: 0))
|
check_reset_options(GenericTestEnv(reset_fn=lambda self: (0, {})))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@@ -242,24 +242,20 @@ def test_check_obs(test, obs, obs_space: spaces.Space, message: str):
|
|||||||
assert len(warnings) == 0
|
assert len(warnings) == 0
|
||||||
|
|
||||||
|
|
||||||
def _reset_no_seed(self, return_info=False, options=None):
|
def _reset_no_seed(self, options=None):
|
||||||
return self.observation_space.sample()
|
return self.observation_space.sample(), {}
|
||||||
|
|
||||||
|
|
||||||
def _reset_seed_default(self, seed="error", return_info=False, options=None):
|
def _reset_seed_default(self, seed="error", options=None):
|
||||||
return self.observation_space.sample()
|
return self.observation_space.sample(), {}
|
||||||
|
|
||||||
|
|
||||||
def _reset_no_return_info(self, seed=None, options=None):
|
def _reset_no_option(self, seed=None):
|
||||||
return self.observation_space.sample()
|
return self.observation_space.sample(), {}
|
||||||
|
|
||||||
|
|
||||||
def _reset_no_option(self, seed=None, return_info=False):
|
|
||||||
return self.observation_space.sample()
|
|
||||||
|
|
||||||
|
|
||||||
def _make_reset_results(results):
|
def _make_reset_results(results):
|
||||||
def _reset_result(self, seed=None, return_info=False, options=None):
|
def _reset_result(self, seed=None, options=None):
|
||||||
return results
|
return results
|
||||||
|
|
||||||
return _reset_result
|
return _reset_result
|
||||||
@@ -280,12 +276,6 @@ def _make_reset_results(results):
|
|||||||
"The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic. Actual default: seed='error'",
|
"The default seed argument in `Env.reset` should be `None`, otherwise the environment will by default always be deterministic. Actual default: seed='error'",
|
||||||
{},
|
{},
|
||||||
],
|
],
|
||||||
[
|
|
||||||
UserWarning,
|
|
||||||
_reset_no_return_info,
|
|
||||||
"Future gym versions will require that `Env.reset` can be passed `return_info` to return information from the environment resetting.",
|
|
||||||
{},
|
|
||||||
],
|
|
||||||
[
|
[
|
||||||
UserWarning,
|
UserWarning,
|
||||||
_reset_no_option,
|
_reset_no_option,
|
||||||
@@ -293,16 +283,16 @@ def _make_reset_results(results):
|
|||||||
{},
|
{},
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
AssertionError,
|
UserWarning,
|
||||||
_make_reset_results([0, {}]),
|
_make_reset_results([0, {}]),
|
||||||
"The result returned by `env.reset(return_info=True)` was not a tuple, actual type: <class 'list'>",
|
"The result returned by `env.reset()` was not a tuple of the form `(obs, info)`, where `obs` is a observation and `info` is a dictionary containing additional information. Actual type: `<class 'list'>`",
|
||||||
{"return_info": True},
|
{},
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
AssertionError,
|
AssertionError,
|
||||||
_make_reset_results((0, {1, 2})),
|
_make_reset_results((np.array([0], dtype=np.float32), {1, 2})),
|
||||||
"The second element returned by `env.reset(return_info=True)` was not a dictionary, actual type: <class 'set'>",
|
"The second element returned by `env.reset()` was not a dictionary, actual type: <class 'set'>",
|
||||||
{"return_info": True},
|
{},
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -317,6 +307,8 @@ def test_passive_env_reset_checker(test, func: callable, message: str, kwargs: D
|
|||||||
with pytest.warns(None) as warnings:
|
with pytest.warns(None) as warnings:
|
||||||
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
with pytest.raises(test, match=f"^{re.escape(message)}$"):
|
||||||
env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs)
|
env_reset_passive_checker(GenericTestEnv(reset_fn=func), **kwargs)
|
||||||
|
for warning in warnings:
|
||||||
|
print(warning)
|
||||||
assert len(warnings) == 0
|
assert len(warnings) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@@ -28,7 +28,7 @@ def test_reset_async_vector_env(shared_memory):
|
|||||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||||
|
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
observations = env.reset()
|
observations, infos = env.reset()
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
@@ -40,19 +40,7 @@ def test_reset_async_vector_env(shared_memory):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
observations = env.reset(return_info=False)
|
observations, infos = env.reset()
|
||||||
finally:
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
assert isinstance(env.observation_space, Box)
|
|
||||||
assert isinstance(observations, np.ndarray)
|
|
||||||
assert observations.dtype == env.observation_space.dtype
|
|
||||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
|
||||||
assert observations.shape == env.observation_space.shape
|
|
||||||
|
|
||||||
try:
|
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
|
||||||
observations, infos = env.reset(return_info=True)
|
|
||||||
finally:
|
finally:
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
@@ -143,7 +131,7 @@ def test_copy_async_vector_env(shared_memory):
|
|||||||
|
|
||||||
# TODO, these tests do nothing, understand the purpose of the tests and fix them
|
# TODO, these tests do nothing, understand the purpose of the tests and fix them
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=True)
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=True)
|
||||||
observations = env.reset()
|
observations, infos = env.reset()
|
||||||
observations[0] = 0
|
observations[0] = 0
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
@@ -155,7 +143,7 @@ def test_no_copy_async_vector_env(shared_memory):
|
|||||||
|
|
||||||
# TODO, these tests do nothing, understand the purpose of the tests and fix them
|
# TODO, these tests do nothing, understand the purpose of the tests and fix them
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=False)
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=False)
|
||||||
observations = env.reset()
|
observations, infos = env.reset()
|
||||||
observations[0] = 0
|
observations[0] = 0
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
@@ -268,7 +256,7 @@ def test_custom_space_async_vector_env():
|
|||||||
env_fns = [make_custom_space_env(i) for i in range(4)]
|
env_fns = [make_custom_space_env(i) for i in range(4)]
|
||||||
|
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=False)
|
env = AsyncVectorEnv(env_fns, shared_memory=False)
|
||||||
reset_observations = env.reset()
|
reset_observations, reset_infos = env.reset()
|
||||||
|
|
||||||
assert isinstance(env.single_action_space, CustomSpace)
|
assert isinstance(env.single_action_space, CustomSpace)
|
||||||
assert isinstance(env.action_space, Tuple)
|
assert isinstance(env.action_space, Tuple)
|
||||||
|
@@ -12,7 +12,7 @@ class OldStepEnv(gym.Env):
|
|||||||
self.observation_space = Discrete(2)
|
self.observation_space = Discrete(2)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
return 0
|
return 0, {}
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
obs = self.observation_space.sample()
|
obs = self.observation_space.sample()
|
||||||
@@ -28,7 +28,7 @@ class NewStepEnv(gym.Env):
|
|||||||
self.observation_space = Discrete(2)
|
self.observation_space = Discrete(2)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
return 0
|
return 0, {}
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
obs = self.observation_space.sample()
|
obs = self.observation_space.sample()
|
||||||
|
@@ -24,7 +24,7 @@ def test_create_sync_vector_env():
|
|||||||
def test_reset_sync_vector_env():
|
def test_reset_sync_vector_env():
|
||||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||||
env = SyncVectorEnv(env_fns)
|
env = SyncVectorEnv(env_fns)
|
||||||
observations = env.reset()
|
observations, infos = env.reset()
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
assert isinstance(env.observation_space, Box)
|
assert isinstance(env.observation_space, Box)
|
||||||
@@ -35,32 +35,6 @@ def test_reset_sync_vector_env():
|
|||||||
|
|
||||||
del observations
|
del observations
|
||||||
|
|
||||||
env = SyncVectorEnv(env_fns)
|
|
||||||
observations = env.reset(return_info=False)
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
assert isinstance(env.observation_space, Box)
|
|
||||||
assert isinstance(observations, np.ndarray)
|
|
||||||
assert observations.dtype == env.observation_space.dtype
|
|
||||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
|
||||||
assert observations.shape == env.observation_space.shape
|
|
||||||
|
|
||||||
del observations
|
|
||||||
|
|
||||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
|
||||||
|
|
||||||
env = SyncVectorEnv(env_fns)
|
|
||||||
observations, infos = env.reset(return_info=True)
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
assert isinstance(env.observation_space, Box)
|
|
||||||
assert isinstance(observations, np.ndarray)
|
|
||||||
assert observations.dtype == env.observation_space.dtype
|
|
||||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
|
||||||
assert observations.shape == env.observation_space.shape
|
|
||||||
assert isinstance(infos, dict)
|
|
||||||
assert all([isinstance(info, dict) for info in infos])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
||||||
def test_step_sync_vector_env(use_single_action_space):
|
def test_step_sync_vector_env(use_single_action_space):
|
||||||
@@ -145,7 +119,7 @@ def test_custom_space_sync_vector_env():
|
|||||||
env_fns = [make_custom_space_env(i) for i in range(4)]
|
env_fns = [make_custom_space_env(i) for i in range(4)]
|
||||||
|
|
||||||
env = SyncVectorEnv(env_fns)
|
env = SyncVectorEnv(env_fns)
|
||||||
reset_observations = env.reset()
|
reset_observations, infos = env.reset()
|
||||||
|
|
||||||
assert isinstance(env.single_action_space, CustomSpace)
|
assert isinstance(env.single_action_space, CustomSpace)
|
||||||
assert isinstance(env.action_space, Tuple)
|
assert isinstance(env.action_space, Tuple)
|
||||||
|
@@ -22,8 +22,8 @@ def test_vector_env_equal(shared_memory):
|
|||||||
assert async_env.action_space == sync_env.action_space
|
assert async_env.action_space == sync_env.action_space
|
||||||
assert async_env.single_action_space == sync_env.single_action_space
|
assert async_env.single_action_space == sync_env.single_action_space
|
||||||
|
|
||||||
async_observations = async_env.reset(seed=0)
|
async_observations, async_infos = async_env.reset(seed=0)
|
||||||
sync_observations = sync_env.reset(seed=0)
|
sync_observations, sync_infos = sync_env.reset(seed=0)
|
||||||
assert np.all(async_observations == sync_observations)
|
assert np.all(async_observations == sync_observations)
|
||||||
|
|
||||||
for _ in range(num_steps):
|
for _ in range(num_steps):
|
||||||
|
@@ -63,7 +63,7 @@ class UnittestSlowEnv(gym.Env):
|
|||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
if self.slow_reset > 0:
|
if self.slow_reset > 0:
|
||||||
time.sleep(self.slow_reset)
|
time.sleep(self.slow_reset)
|
||||||
return self.observation_space.sample()
|
return self.observation_space.sample(), {}
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
time.sleep(action)
|
time.sleep(action)
|
||||||
@@ -99,7 +99,7 @@ class CustomSpaceEnv(gym.Env):
|
|||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
return "reset"
|
return "reset", {}
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
observation = f"step({action:s})"
|
observation = f"step({action:s})"
|
||||||
|
@@ -86,9 +86,7 @@ def test_atari_preprocessing_grayscale(env, obs_shape):
|
|||||||
# It is not possible to test the outputs as we are not using actual observations.
|
# It is not possible to test the outputs as we are not using actual observations.
|
||||||
# todo: update when ale-py is compatible with the ci
|
# todo: update when ale-py is compatible with the ci
|
||||||
|
|
||||||
obs = env.reset(seed=0)
|
obs, _ = env.reset(seed=0)
|
||||||
assert obs in env.observation_space
|
|
||||||
obs, _ = env.reset(seed=0, return_info=True)
|
|
||||||
assert obs in env.observation_space
|
assert obs in env.observation_space
|
||||||
|
|
||||||
obs, _, _, _ = env.step(env.action_space.sample())
|
obs, _, _, _ = env.step(env.action_space.sample())
|
||||||
@@ -110,7 +108,7 @@ def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10):
|
|||||||
noop_max=0,
|
noop_max=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
obs = env.reset()
|
obs, _ = env.reset()
|
||||||
|
|
||||||
max_obs = 1 if scaled else 255
|
max_obs = 1 if scaled else 255
|
||||||
assert np.all(0 <= obs) and np.all(obs <= max_obs)
|
assert np.all(0 <= obs) and np.all(obs <= max_obs)
|
||||||
|
@@ -39,19 +39,10 @@ class DummyResetEnv(gym.Env):
|
|||||||
{"count": self.count}, # Info
|
{"count": self.count}, # Info
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset(
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
self,
|
|
||||||
*,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
return_info: Optional[bool] = False,
|
|
||||||
options: Optional[dict] = None
|
|
||||||
):
|
|
||||||
"""Resets the DummyEnv to return the count array and info with count."""
|
"""Resets the DummyEnv to return the count array and info with count."""
|
||||||
self.count = 0
|
self.count = 0
|
||||||
if not return_info:
|
return np.array([self.count]), {"count": self.count}
|
||||||
return np.array([self.count])
|
|
||||||
else:
|
|
||||||
return np.array([self.count]), {"count": self.count}
|
|
||||||
|
|
||||||
|
|
||||||
def unwrap_env(env) -> Generator[gym.Wrapper, None, None]:
|
def unwrap_env(env) -> Generator[gym.Wrapper, None, None]:
|
||||||
@@ -113,7 +104,7 @@ def test_autoreset_wrapper_autoreset():
|
|||||||
env = DummyResetEnv()
|
env = DummyResetEnv()
|
||||||
env = AutoResetWrapper(env)
|
env = AutoResetWrapper(env)
|
||||||
|
|
||||||
obs, info = env.reset(return_info=True)
|
obs, info = env.reset()
|
||||||
assert obs == np.array([0])
|
assert obs == np.array([0])
|
||||||
assert info == {"count": 0}
|
assert info == {"count": 0}
|
||||||
|
|
||||||
|
@@ -25,16 +25,10 @@ class FakeEnvironment(gym.Env):
|
|||||||
image_shape = (32, 32, 3)
|
image_shape = (32, 32, 3)
|
||||||
return np.zeros(image_shape, dtype=np.uint8)
|
return np.zeros(image_shape, dtype=np.uint8)
|
||||||
|
|
||||||
def reset(
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
self,
|
|
||||||
*,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
return_info: bool = False,
|
|
||||||
options: Optional[dict] = None
|
|
||||||
):
|
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
observation = self.observation_space.sample()
|
observation = self.observation_space.sample()
|
||||||
return observation if not return_info else (observation, {})
|
return observation, {}
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
del action
|
del action
|
||||||
@@ -79,8 +73,9 @@ class TestFilterObservation:
|
|||||||
assert tuple(wrapped_env.observation_space.spaces.keys()) == tuple(filter_keys)
|
assert tuple(wrapped_env.observation_space.spaces.keys()) == tuple(filter_keys)
|
||||||
|
|
||||||
# Check that the added space item is consistent with the added observation.
|
# Check that the added space item is consistent with the added observation.
|
||||||
observation = wrapped_env.reset()
|
observation, info = wrapped_env.reset()
|
||||||
assert len(observation) == len(filter_keys)
|
assert len(observation) == len(filter_keys)
|
||||||
|
assert isinstance(info, dict)
|
||||||
|
|
||||||
@pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES)
|
@pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES)
|
||||||
def test_raises_with_incorrect_arguments(
|
def test_raises_with_incorrect_arguments(
|
||||||
|
@@ -18,7 +18,7 @@ class FakeEnvironment(gym.Env):
|
|||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.observation = self.observation_space.sample()
|
self.observation = self.observation_space.sample()
|
||||||
return self.observation
|
return self.observation, {}
|
||||||
|
|
||||||
|
|
||||||
OBSERVATION_SPACES = (
|
OBSERVATION_SPACES = (
|
||||||
@@ -67,7 +67,7 @@ class TestFlattenEnvironment:
|
|||||||
"""
|
"""
|
||||||
env = FakeEnvironment(observation_space=observation_space)
|
env = FakeEnvironment(observation_space=observation_space)
|
||||||
wrapped_env = FlattenObservation(env)
|
wrapped_env = FlattenObservation(env)
|
||||||
flattened = wrapped_env.reset()
|
flattened, info = wrapped_env.reset()
|
||||||
|
|
||||||
unflattened = unflatten(env.observation_space, flattened)
|
unflattened = unflatten(env.observation_space, flattened)
|
||||||
original = env.observation
|
original = env.observation
|
||||||
|
@@ -11,11 +11,13 @@ def test_flatten_observation(env_id):
|
|||||||
env = gym.make(env_id, disable_env_checker=True)
|
env = gym.make(env_id, disable_env_checker=True)
|
||||||
wrapped_env = FlattenObservation(env)
|
wrapped_env = FlattenObservation(env)
|
||||||
|
|
||||||
obs = env.reset()
|
obs, info = env.reset()
|
||||||
wrapped_obs = wrapped_env.reset()
|
wrapped_obs, wrapped_obs_info = wrapped_env.reset()
|
||||||
|
|
||||||
space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
|
space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
|
||||||
wrapped_space = spaces.Box(0, 1, [32 + 11 + 2], dtype=np.int64)
|
wrapped_space = spaces.Box(0, 1, [32 + 11 + 2], dtype=np.int64)
|
||||||
|
|
||||||
assert space.contains(obs)
|
assert space.contains(obs)
|
||||||
assert wrapped_space.contains(wrapped_obs)
|
assert wrapped_space.contains(wrapped_obs)
|
||||||
|
assert isinstance(info, dict)
|
||||||
|
assert isinstance(wrapped_obs_info, dict)
|
||||||
|
@@ -33,8 +33,8 @@ def test_frame_stack(env_id, num_stack, lz4_compress):
|
|||||||
|
|
||||||
dup = gym.make(env_id, disable_env_checker=True)
|
dup = gym.make(env_id, disable_env_checker=True)
|
||||||
|
|
||||||
obs = env.reset(seed=0)
|
obs, _ = env.reset(seed=0)
|
||||||
dup_obs = dup.reset(seed=0)
|
dup_obs, _ = dup.reset(seed=0)
|
||||||
assert np.allclose(obs[-1], dup_obs)
|
assert np.allclose(obs[-1], dup_obs)
|
||||||
|
|
||||||
for _ in range(num_stack**2):
|
for _ in range(num_stack**2):
|
||||||
|
@@ -22,5 +22,5 @@ def test_gray_scale_observation(env_id, keep_dim):
|
|||||||
else:
|
else:
|
||||||
assert len(wrapped_env.observation_space.shape) == 2
|
assert len(wrapped_env.observation_space.shape) == 2
|
||||||
|
|
||||||
wrapped_obs = wrapped_env.reset()
|
wrapped_obs, info = wrapped_env.reset()
|
||||||
assert wrapped_obs in wrapped_env.observation_space
|
assert wrapped_obs in wrapped_env.observation_space
|
||||||
|
@@ -23,7 +23,7 @@ class FakeEnvironment(gym.Env):
|
|||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
observation = self.observation_space.sample()
|
observation = self.observation_space.sample()
|
||||||
return observation
|
return observation, {}
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
del action
|
del action
|
||||||
@@ -115,5 +115,6 @@ class TestNestedDictWrapper:
|
|||||||
def test_nested_dicts_ravel(self, observation_space, flat_shape):
|
def test_nested_dicts_ravel(self, observation_space, flat_shape):
|
||||||
env = FakeEnvironment(observation_space=observation_space)
|
env = FakeEnvironment(observation_space=observation_space)
|
||||||
wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
|
wrapped_env = FlattenObservation(FilterObservation(env, env.obs_keys))
|
||||||
obs = wrapped_env.reset()
|
obs, info = wrapped_env.reset()
|
||||||
assert obs.shape == wrapped_env.observation_space.shape
|
assert obs.shape == wrapped_env.observation_space.shape
|
||||||
|
assert isinstance(info, dict)
|
||||||
|
@@ -1,7 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
from numpy.testing import assert_almost_equal
|
from numpy.testing import assert_almost_equal
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
@@ -24,19 +23,10 @@ class DummyRewardEnv(gym.Env):
|
|||||||
self.t += 1
|
self.t += 1
|
||||||
return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}
|
return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}
|
||||||
|
|
||||||
def reset(
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
self,
|
|
||||||
*,
|
|
||||||
seed: Optional[int] = None,
|
|
||||||
return_info: Optional[bool] = False,
|
|
||||||
options: Optional[dict] = None
|
|
||||||
):
|
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self.t = self.return_reward_idx
|
self.t = self.return_reward_idx
|
||||||
if not return_info:
|
return np.array([self.t]), {}
|
||||||
return np.array([self.t])
|
|
||||||
else:
|
|
||||||
return np.array([self.t]), {}
|
|
||||||
|
|
||||||
|
|
||||||
def make_env(return_reward_idx):
|
def make_env(return_reward_idx):
|
||||||
@@ -47,11 +37,10 @@ def make_env(return_reward_idx):
|
|||||||
return thunk
|
return thunk
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("return_info", [False, True])
|
def test_normalize_observation():
|
||||||
def test_normalize_observation(return_info: bool):
|
|
||||||
env = DummyRewardEnv(return_reward_idx=0)
|
env = DummyRewardEnv(return_reward_idx=0)
|
||||||
env = NormalizeObservation(env)
|
env = NormalizeObservation(env)
|
||||||
env.reset(return_info=return_info)
|
env.reset()
|
||||||
env.step(env.action_space.sample())
|
env.step(env.action_space.sample())
|
||||||
assert_almost_equal(env.obs_rms.mean, 0.5, decimal=4)
|
assert_almost_equal(env.obs_rms.mean, 0.5, decimal=4)
|
||||||
env.step(env.action_space.sample())
|
env.step(env.action_space.sample())
|
||||||
@@ -61,13 +50,7 @@ def test_normalize_observation(return_info: bool):
|
|||||||
def test_normalize_reset_info():
|
def test_normalize_reset_info():
|
||||||
env = DummyRewardEnv(return_reward_idx=0)
|
env = DummyRewardEnv(return_reward_idx=0)
|
||||||
env = NormalizeObservation(env)
|
env = NormalizeObservation(env)
|
||||||
obs = env.reset()
|
obs, info = env.reset()
|
||||||
assert isinstance(obs, np.ndarray)
|
|
||||||
del obs
|
|
||||||
obs = env.reset(return_info=False)
|
|
||||||
assert isinstance(obs, np.ndarray)
|
|
||||||
del obs
|
|
||||||
obs, info = env.reset(return_info=True)
|
|
||||||
assert isinstance(obs, np.ndarray)
|
assert isinstance(obs, np.ndarray)
|
||||||
assert isinstance(info, dict)
|
assert isinstance(info, dict)
|
||||||
|
|
||||||
|
@@ -57,8 +57,8 @@ def test_initialise_failures(env, message):
|
|||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
def _reset_failure(self, seed=None, return_info=False, options=None):
|
def _reset_failure(self, seed=None, options=None):
|
||||||
return np.array([-1.0], dtype=np.float32)
|
return np.array([-1.0], dtype=np.float32), {}
|
||||||
|
|
||||||
|
|
||||||
def _step_failure(self, action):
|
def _step_failure(self, action):
|
||||||
|
@@ -21,7 +21,7 @@ class FakeEnvironment(gym.Env):
|
|||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
observation = self.observation_space.sample()
|
observation = self.observation_space.sample()
|
||||||
return observation
|
return observation, {}
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
del action
|
del action
|
||||||
@@ -82,9 +82,10 @@ def test_dict_observation(pixels_only):
|
|||||||
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
|
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
|
||||||
|
|
||||||
# Check that the added space item is consistent with the added observation.
|
# Check that the added space item is consistent with the added observation.
|
||||||
observation = wrapped_env.reset()
|
observation, info = wrapped_env.reset()
|
||||||
rgb_observation = observation[pixel_key]
|
rgb_observation = observation[pixel_key]
|
||||||
|
|
||||||
|
assert isinstance(info, dict)
|
||||||
assert rgb_observation.shape == (height, width, 3)
|
assert rgb_observation.shape == (height, width, 3)
|
||||||
assert rgb_observation.dtype == np.uint8
|
assert rgb_observation.dtype == np.uint8
|
||||||
|
|
||||||
@@ -113,9 +114,10 @@ def test_single_array_observation(pixels_only):
|
|||||||
pixel_key,
|
pixel_key,
|
||||||
]
|
]
|
||||||
|
|
||||||
observation = wrapped_env.reset()
|
observation, info = wrapped_env.reset()
|
||||||
depth_observation = observation[pixel_key]
|
depth_observation = observation[pixel_key]
|
||||||
|
|
||||||
|
assert isinstance(info, dict)
|
||||||
assert depth_observation.shape == (32, 32, 3)
|
assert depth_observation.shape == (32, 32, 3)
|
||||||
assert depth_observation.dtype == np.uint8
|
assert depth_observation.dtype == np.uint8
|
||||||
|
|
||||||
|
@@ -31,10 +31,7 @@ def test_record_episode_statistics_reset_info():
|
|||||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
env = gym.make("CartPole-v1", disable_env_checker=True)
|
||||||
env = RecordEpisodeStatistics(env)
|
env = RecordEpisodeStatistics(env)
|
||||||
ob_space = env.observation_space
|
ob_space = env.observation_space
|
||||||
obs = env.reset()
|
obs, info = env.reset()
|
||||||
assert ob_space.contains(obs)
|
|
||||||
del obs
|
|
||||||
obs, info = env.reset(return_info=True)
|
|
||||||
assert ob_space.contains(obs)
|
assert ob_space.contains(obs)
|
||||||
assert isinstance(info, dict)
|
assert isinstance(info, dict)
|
||||||
|
|
||||||
|
@@ -23,35 +23,17 @@ def test_record_video_using_default_trigger():
|
|||||||
shutil.rmtree("videos")
|
shutil.rmtree("videos")
|
||||||
|
|
||||||
|
|
||||||
def test_record_video_reset_return_info():
|
def test_record_video_reset():
|
||||||
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
|
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
|
||||||
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
|
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
|
||||||
ob_space = env.observation_space
|
ob_space = env.observation_space
|
||||||
obs, info = env.reset(return_info=True)
|
obs, info = env.reset()
|
||||||
env.close()
|
env.close()
|
||||||
assert os.path.isdir("videos")
|
assert os.path.isdir("videos")
|
||||||
shutil.rmtree("videos")
|
shutil.rmtree("videos")
|
||||||
assert ob_space.contains(obs)
|
assert ob_space.contains(obs)
|
||||||
assert isinstance(info, dict)
|
assert isinstance(info, dict)
|
||||||
|
|
||||||
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
|
|
||||||
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
|
|
||||||
ob_space = env.observation_space
|
|
||||||
obs = env.reset(return_info=False)
|
|
||||||
env.close()
|
|
||||||
assert os.path.isdir("videos")
|
|
||||||
shutil.rmtree("videos")
|
|
||||||
assert ob_space.contains(obs)
|
|
||||||
|
|
||||||
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
|
|
||||||
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
|
|
||||||
ob_space = env.observation_space
|
|
||||||
obs = env.reset()
|
|
||||||
env.close()
|
|
||||||
assert os.path.isdir("videos")
|
|
||||||
shutil.rmtree("videos")
|
|
||||||
assert ob_space.contains(obs)
|
|
||||||
|
|
||||||
|
|
||||||
def test_record_video_step_trigger():
|
def test_record_video_step_trigger():
|
||||||
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
|
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
|
||||||
|
@@ -18,8 +18,8 @@ def test_rescale_action():
|
|||||||
|
|
||||||
seed = 0
|
seed = 0
|
||||||
|
|
||||||
obs = env.reset(seed=seed)
|
obs, info = env.reset(seed=seed)
|
||||||
wrapped_obs = wrapped_env.reset(seed=seed)
|
wrapped_obs, wrapped_obs_info = wrapped_env.reset(seed=seed)
|
||||||
assert np.allclose(obs, wrapped_obs)
|
assert np.allclose(obs, wrapped_obs)
|
||||||
|
|
||||||
obs, reward, _, _ = env.step([1.5])
|
obs, reward, _, _ = env.step([1.5])
|
||||||
|
@@ -13,7 +13,7 @@ def test_resize_observation(env_id, shape):
|
|||||||
|
|
||||||
assert isinstance(env.observation_space, spaces.Box)
|
assert isinstance(env.observation_space, spaces.Box)
|
||||||
assert env.observation_space.shape[-1] == 3
|
assert env.observation_space.shape[-1] == 3
|
||||||
obs = env.reset()
|
obs, _ = env.reset()
|
||||||
if isinstance(shape, int):
|
if isinstance(shape, int):
|
||||||
assert env.observation_space.shape[:2] == (shape, shape)
|
assert env.observation_space.shape[:2] == (shape, shape)
|
||||||
assert obs.shape == (shape, shape, 3)
|
assert obs.shape == (shape, shape, 3)
|
||||||
|
@@ -14,8 +14,8 @@ def test_time_aware_observation(env_id):
|
|||||||
assert isinstance(wrapped_env.observation_space, spaces.Box)
|
assert isinstance(wrapped_env.observation_space, spaces.Box)
|
||||||
assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1
|
assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1
|
||||||
|
|
||||||
obs = env.reset()
|
obs, info = env.reset()
|
||||||
wrapped_obs = wrapped_env.reset()
|
wrapped_obs, wrapped_obs_info = wrapped_env.reset()
|
||||||
assert wrapped_env.t == 0.0
|
assert wrapped_env.t == 0.0
|
||||||
assert wrapped_obs[-1] == 0.0
|
assert wrapped_obs[-1] == 0.0
|
||||||
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
||||||
@@ -30,7 +30,7 @@ def test_time_aware_observation(env_id):
|
|||||||
assert wrapped_obs[-1] == 2.0
|
assert wrapped_obs[-1] == 2.0
|
||||||
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
||||||
|
|
||||||
wrapped_obs = wrapped_env.reset()
|
wrapped_obs, wrapped_obs_info = wrapped_env.reset()
|
||||||
assert wrapped_env.t == 0.0
|
assert wrapped_env.t == 0.0
|
||||||
assert wrapped_obs[-1] == 0.0
|
assert wrapped_obs[-1] == 0.0
|
||||||
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
assert wrapped_obs.shape[0] == obs.shape[0] + 1
|
||||||
|
@@ -9,13 +9,7 @@ def test_time_limit_reset_info():
|
|||||||
env = gym.make("CartPole-v1", disable_env_checker=True)
|
env = gym.make("CartPole-v1", disable_env_checker=True)
|
||||||
env = TimeLimit(env)
|
env = TimeLimit(env)
|
||||||
ob_space = env.observation_space
|
ob_space = env.observation_space
|
||||||
obs = env.reset()
|
obs, info = env.reset()
|
||||||
assert ob_space.contains(obs)
|
|
||||||
del obs
|
|
||||||
obs = env.reset(return_info=False)
|
|
||||||
assert ob_space.contains(obs)
|
|
||||||
del obs
|
|
||||||
obs, info = env.reset(return_info=True)
|
|
||||||
assert ob_space.contains(obs)
|
assert ob_space.contains(obs)
|
||||||
assert isinstance(info, dict)
|
assert isinstance(info, dict)
|
||||||
|
|
||||||
|
@@ -15,9 +15,10 @@ def test_transform_observation(env_id):
|
|||||||
gym.make(env_id, disable_env_checker=True), lambda obs: affine_transform(obs)
|
gym.make(env_id, disable_env_checker=True), lambda obs: affine_transform(obs)
|
||||||
)
|
)
|
||||||
|
|
||||||
obs = env.reset(seed=0)
|
obs, info = env.reset(seed=0)
|
||||||
wrapped_obs = wrapped_env.reset(seed=0)
|
wrapped_obs, wrapped_obs_info = wrapped_env.reset(seed=0)
|
||||||
assert np.allclose(wrapped_obs, affine_transform(obs))
|
assert np.allclose(wrapped_obs, affine_transform(obs))
|
||||||
|
assert isinstance(wrapped_obs_info, dict)
|
||||||
|
|
||||||
action = env.action_space.sample()
|
action = env.action_space.sample()
|
||||||
obs, reward, done, _ = env.step(action)
|
obs, reward, done, _ = env.step(action)
|
||||||
|
@@ -23,7 +23,7 @@ def test_info_to_list():
|
|||||||
env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)
|
env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)
|
||||||
wrapped_env = VectorListInfo(env_to_wrap)
|
wrapped_env = VectorListInfo(env_to_wrap)
|
||||||
wrapped_env.action_space.seed(SEED)
|
wrapped_env.action_space.seed(SEED)
|
||||||
_, info = wrapped_env.reset(seed=SEED, return_info=True)
|
_, info = wrapped_env.reset(seed=SEED)
|
||||||
assert isinstance(info, list)
|
assert isinstance(info, list)
|
||||||
assert len(info) == NUM_ENVS
|
assert len(info) == NUM_ENVS
|
||||||
|
|
||||||
@@ -40,7 +40,7 @@ def test_info_to_list():
|
|||||||
def test_info_to_list_statistics():
|
def test_info_to_list_statistics():
|
||||||
env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)
|
env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)
|
||||||
wrapped_env = VectorListInfo(RecordEpisodeStatistics(env_to_wrap))
|
wrapped_env = VectorListInfo(RecordEpisodeStatistics(env_to_wrap))
|
||||||
_, info = wrapped_env.reset(seed=SEED, return_info=True)
|
_, info = wrapped_env.reset(seed=SEED)
|
||||||
wrapped_env.action_space.seed(SEED)
|
wrapped_env.action_space.seed(SEED)
|
||||||
assert isinstance(info, list)
|
assert isinstance(info, list)
|
||||||
assert len(info) == NUM_ENVS
|
assert len(info) == NUM_ENVS
|
||||||
|
Reference in New Issue
Block a user