Adding return_info argument to reset to allow for optional info dict as a second return value (#2546)

* initial draft of optional info dict in reset function, implemented for cartpole, tests seem to be passing

* merged core.py

* updated return type annotation for reset function in core.py

* optional metadata with return_info from reset added for all first party environments, with corresponding tests. Incomplete implementation for wrappers and vector wrappers

* removed Optional type for return_info arguments

* added tests for return_info to normalize wrapper and sync_vector_env

* autoformatted using black

* added optional reset metadata tests to several wrappers

* added return_info capability to async_vector_env.py and test to verify functionality

* added optional return_info test for record_video.py

* removed tests for mujoco environments

* autoformatted

* improved test coverage for optional reset return_info

* re-removed unit test envs accidentally reintroduced in merge

* removed unnecessary import

* changes based on code-review

* small fix to core wrapper typing and autoformatted record_episode_stats

* small change to pass flake8 style
This commit is contained in:
John Balis
2022-02-06 17:28:27 -06:00
committed by GitHub
parent 62e52727d5
commit 15049e22d7
27 changed files with 441 additions and 62 deletions

View File

@@ -1,8 +1,7 @@
from __future__ import annotations
from abc import abstractmethod
from typing import TypeVar, Generic, Tuple, SupportsFloat
from typing import Optional
from typing import TypeVar, Generic, Tuple, Union, Optional, SupportsFloat
import gym
from gym import error, spaces
@@ -71,8 +70,12 @@ class Env(Generic[ObsType, ActType]):
@abstractmethod
def reset(
self, *, seed: Optional[int] = None, options: Optional[dict] = None
) -> ObsType:
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
) -> Union[ObsType, tuple[ObsType, dict]]:
"""Resets the environment to an initial state and returns an initial
observation.
@@ -84,6 +87,7 @@ class Env(Generic[ObsType, ActType]):
Returns:
observation (object): the initial observation.
info (optional dictionary): a dictionary containing extra information, this is only returned if return_info is set to true
"""
# Initialize the RNG if it's the first reset, or if the seed is manually passed
if seed is not None or self.np_random is None:
@@ -262,7 +266,7 @@ class Wrapper(Env[ObsType, ActType]):
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
return self.env.step(action)
def reset(self, **kwargs) -> ObsType:
def reset(self, **kwargs) -> Union[ObsType, tuple[ObsType, dict]]:
return self.env.reset(**kwargs)
def render(self, mode="human", **kwargs):

View File

@@ -350,7 +350,13 @@ class BipedalWalker(gym.Env, EzPickle):
x2 = max(p[0] for p in poly)
self.cloud_poly.append((poly, x1, x2))
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
self._destroy()
self.world.contactListener_bug_workaround = ContactDetector(self)
@@ -436,8 +442,10 @@ class BipedalWalker(gym.Env, EzPickle):
return fraction
self.lidar = [LidarCallback() for _ in range(10)]
return self.step(np.array([0, 0, 0, 0]))[0]
if not return_info:
return self.step(np.array([0, 0, 0, 0]))[0]
else:
return self.step(np.array([0, 0, 0, 0]))[0], {}
def step(self, action):
# self.hull.ApplyForceToCenter((0, 20), True) -- Uncomment this to receive a bit of stability help

View File

@@ -374,7 +374,13 @@ class CarRacing(gym.Env, EzPickle):
self.track = track
return True
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
self._destroy()
self.reward = 0.0
@@ -395,7 +401,10 @@ class CarRacing(gym.Env, EzPickle):
)
self.car = Car(self.world, *self.track[0][1:4])
return self.step(None)[0]
if not return_info:
return self.step(None)[0]
else:
return self.step(None)[0], {}
def step(self, action):
if action is not None:

View File

@@ -183,7 +183,13 @@ class LunarLander(gym.Env, EzPickle):
self.world.DestroyBody(self.legs[0])
self.world.DestroyBody(self.legs[1])
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
self._destroy()
self.world.contactListener_keepref = ContactDetector(self)
@@ -288,7 +294,10 @@ class LunarLander(gym.Env, EzPickle):
self.drawlist = [self.lander] + self.legs
return self.step(np.array([0, 0]) if self.continuous else 0)[0]
if not return_info:
return self.step(np.array([0, 0]) if self.continuous else 0)[0]
else:
return self.step(np.array([0, 0]) if self.continuous else 0)[0], {}
def _create_particle(self, mass, x, y, ttl):
p = self.world.CreateDynamicBody(

View File

@@ -162,12 +162,21 @@ class AcrobotEnv(core.Env):
self.action_space = spaces.Discrete(3)
self.state = None
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed)
self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)).astype(
np.float32
)
return self._get_ob()
if not return_info:
return self._get_ob()
else:
return self._get_ob(), {}
def step(self, a):
s = self.state

View File

@@ -165,11 +165,20 @@ class CartPoleEnv(gym.Env):
return np.array(self.state, dtype=np.float32), reward, done, {}
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
self.steps_beyond_done = None
return np.array(self.state, dtype=np.float32)
if not return_info:
return np.array(self.state, dtype=np.float32)
else:
return np.array(self.state, dtype=np.float32), {}
def render(self, mode="human"):
screen_width = 600

View File

@@ -130,10 +130,19 @@ class Continuous_MountainCarEnv(gym.Env):
self.state = np.array([position, velocity], dtype=np.float32)
return self.state, reward, done, {}
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed)
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
return np.array(self.state, dtype=np.float32)
if not return_info:
return np.array(self.state, dtype=np.float32)
else:
return np.array(self.state, dtype=np.float32), {}
def _height(self, xs):
return np.sin(3 * xs) * 0.45 + 0.55

View File

@@ -93,10 +93,19 @@ class MountainCarEnv(gym.Env):
self.state = (position, velocity)
return np.array(self.state, dtype=np.float32), reward, done, {}
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
return np.array(self.state, dtype=np.float32)
if not return_info:
return np.array(self.state, dtype=np.float32)
else:
return np.array(self.state, dtype=np.float32), {}
def _height(self, xs):
return np.sin(3 * xs) * 0.45 + 0.55

View File

@@ -109,12 +109,21 @@ class PendulumEnv(gym.Env):
self.state = np.array([newth, newthdot])
return self._get_obs(), -costs, False, {}
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed)
high = np.array([np.pi, 1])
self.state = self.np_random.uniform(low=-high, high=high)
self.last_u = None
return self._get_obs()
if not return_info:
return self._get_obs()
else:
return self._get_obs(), {}
def _get_obs(self):
theta, thetadot = self.state

View File

@@ -103,11 +103,20 @@ class MujocoEnv(gym.Env):
# -----------------------------
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
self.sim.reset()
ob = self.reset_model()
return ob
if not return_info:
return ob
else:
return ob, {}
def set_state(self, qpos, qvel):
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)

View File

@@ -153,11 +153,19 @@ class BlackjackEnv(gym.Env):
def _get_obs(self):
return (sum_hand(self.player), self.dealer[0], usable_ace(self.player))
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
self.dealer = draw_hand(self.np_random)
self.player = draw_hand(self.np_random)
return self._get_obs()
if not return_info:
return self._get_obs()
else:
return self._get_obs(), {}
def render(self, mode="human"):
player_sum, dealer_card_value, usable_ace = self._get_obs()

View File

@@ -102,11 +102,20 @@ class CliffWalkingEnv(Env):
self.lastaction = a
return (int(s), r, d, {"prob": p})
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed)
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
self.lastaction = None
return int(self.s)
if not return_info:
return int(self.s)
else:
return int(self.s), {"prob": 1}
def render(self, mode="human"):
outfile = StringIO() if mode == "ansi" else sys.stdout

View File

@@ -212,11 +212,21 @@ class FrozenLakeEnv(Env):
self.lastaction = a
return (int(s), r, d, {"prob": p})
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
self.lastaction = None
return int(self.s)
if not return_info:
return int(self.s)
else:
return int(self.s), {"prob": 1}
def render(self, mode="human"):
desc = self.desc.tolist()

View File

@@ -212,11 +212,20 @@ class TaxiEnv(Env):
self.lastaction = a
return (int(s), r, d, {"prob": p})
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed)
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
self.lastaction = None
return int(self.s)
if not return_info:
return int(self.s)
else:
return int(self.s), {"prob": 1}
def render(self, mode="human"):
outfile = StringIO() if mode == "ansi" else sys.stdout

View File

@@ -215,6 +215,7 @@ class AsyncVectorEnv(VectorEnv):
def reset_async(
self,
seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
"""Send the calls to :obj:`reset` to each sub-environment.
@@ -248,6 +249,8 @@ class AsyncVectorEnv(VectorEnv):
single_kwargs = {}
if single_seed is not None:
single_kwargs["seed"] = single_seed
if return_info:
single_kwargs["return_info"] = return_info
if options is not None:
single_kwargs["options"] = options
@@ -255,7 +258,11 @@ class AsyncVectorEnv(VectorEnv):
self._state = AsyncState.WAITING_RESET
def reset_wait(
self, timeout=None, seed: Optional[int] = None, options: Optional[dict] = None
self,
timeout=None,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
"""
Parameters
@@ -270,6 +277,7 @@ class AsyncVectorEnv(VectorEnv):
-------
element of :attr:`~VectorEnv.observation_space`
A batch of observations from the vectorized environment.
infos : list of dicts containing metadata
Raises
------
@@ -300,12 +308,25 @@ class AsyncVectorEnv(VectorEnv):
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
if not self.shared_memory:
self.observations = concatenate(
self.single_observation_space, results, self.observations
)
if return_info:
results, infos = zip(*results)
infos = list(infos)
return deepcopy(self.observations) if self.copy else self.observations
if not self.shared_memory:
self.observations = concatenate(
self.single_observation_space, results, self.observations
)
return (
deepcopy(self.observations) if self.copy else self.observations
), infos
else:
if not self.shared_memory:
self.observations = concatenate(
self.single_observation_space, results, self.observations
)
return deepcopy(self.observations) if self.copy else self.observations
def step_async(self, actions):
"""Send the calls to :obj:`step` to each sub-environment.
@@ -618,8 +639,13 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
while True:
command, data = pipe.recv()
if command == "reset":
observation = env.reset(**data)
pipe.send((observation, True))
if "return_info" in data and data["return_info"] == True:
observation, info = env.reset(**data)
pipe.send(((observation, info), True))
else:
observation = env.reset(**data)
pipe.send((observation, True))
elif command == "step":
observation, reward, done, info = env.step(data)
if done:
@@ -677,11 +703,18 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
while True:
command, data = pipe.recv()
if command == "reset":
observation = env.reset(**data)
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
pipe.send((None, True))
if "return_info" in data and data["return_info"] == True:
observation, info = env.reset(**data)
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
pipe.send(((None, info), True))
else:
observation = env.reset(**data)
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
pipe.send((None, True))
elif command == "step":
observation, reward, done, info = env.step(data)
if done:

View File

@@ -89,6 +89,7 @@ class SyncVectorEnv(VectorEnv):
def reset_wait(
self,
seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
if seed is None:
@@ -99,19 +100,34 @@ class SyncVectorEnv(VectorEnv):
self._dones[:] = False
observations = []
data_list = []
for env, single_seed in zip(self.envs, seed):
single_kwargs = {}
kwargs = {}
if single_seed is not None:
single_kwargs["seed"] = single_seed
kwargs["seed"] = single_seed
if options is not None:
single_kwargs["options"] = options
observation = env.reset(**single_kwargs)
observations.append(observation)
kwargs["options"] = options
if return_info == True:
kwargs["return_info"] = return_info
if not return_info:
observation = env.reset(**kwargs)
observations.append(observation)
else:
observation, data = env.reset(**kwargs)
observations.append(observation)
data_list.append(data)
self.observations = concatenate(
self.single_observation_space, observations, self.observations
)
return deepcopy(self.observations) if self.copy else self.observations
if not return_info:
return deepcopy(self.observations) if self.copy else self.observations
else:
return (
deepcopy(self.observations) if self.copy else self.observations
), data_list
def step_async(self, actions):
self._actions = iterate(self.action_space, actions)

View File

@@ -49,6 +49,7 @@ class VectorEnv(gym.Env):
def reset_async(
self,
seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
pass
@@ -56,6 +57,7 @@ class VectorEnv(gym.Env):
def reset_wait(
self,
seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
raise NotImplementedError()
@@ -64,6 +66,7 @@ class VectorEnv(gym.Env):
self,
*,
seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
r"""Reset all sub-environments and return a batch of initial observations.
@@ -73,8 +76,8 @@ class VectorEnv(gym.Env):
element of :attr:`observation_space`
A batch of observations from the vectorized environment.
"""
self.reset_async(seed=seed, options=options)
return self.reset_wait(seed=seed, options=options)
self.reset_async(seed=seed, return_info=return_info, options=options)
return self.reset_wait(seed=seed, return_info=return_info, options=options)
def step_async(self, actions):
pass

View File

@@ -63,13 +63,26 @@ class NormalizeObservation(gym.core.Wrapper):
obs = self.normalize(np.array([obs]))[0]
return obs, rews, dones, infos
def reset(self, **kwargs):
obs = self.env.reset(**kwargs)
def reset(
self,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
obs = None
info = None
if not return_info:
obs = self.env.reset(seed=seed, options=options)
else:
obs, info = self.env.reset(seed=seed, return_info=True, options=options)
if self.is_vector_env:
obs = self.normalize(obs)
else:
obs = self.normalize(np.array([obs]))[0]
return obs
if not return_info:
return obs
else:
return obs, info
def normalize(self, obs):
self.obs_rms.update(obs)

View File

@@ -55,6 +55,23 @@ def test_env(spec):
env.close()
@pytest.mark.parametrize("spec", spec_list)
def test_reset_info(spec):
    """Check that ``reset`` honours ``return_info`` for a registered environment.

    The environment built from ``spec`` is reset three times: with no
    arguments, with ``return_info=False`` and with ``return_info=True``.
    The first two calls must return an observation contained in the
    observation space; the last must return an ``(observation, info)`` pair
    whose ``info`` is a ``dict``.
    """
    # Imported locally so the test does not depend on a module-level import.
    import warnings

    # ``pytest.warns(None)`` is deprecated since pytest 7.0 and no longer
    # accepted by pytest 8, so construction-time warnings are recorded with
    # the standard library instead.
    # NOTE(review): assumes only ``spec.make()`` sat inside the original
    # ``with`` block -- confirm against the indented source.
    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        env = spec.make()

    try:
        ob_space = env.observation_space

        obs = env.reset()
        assert ob_space.contains(obs)

        obs = env.reset(return_info=False)
        assert ob_space.contains(obs)

        obs, info = env.reset(return_info=True)
        assert ob_space.contains(obs)
        assert isinstance(info, dict)
    finally:
        # Release the environment even when one of the assertions fails.
        env.close()
# Run a longer rollout on some environments
def test_random_rollout():
for env in [envs.make("CartPole-v0"), envs.make("FrozenLake-v1")]:

View File

@@ -35,13 +35,22 @@ class UnknownSpacesEnv(core.Env):
on external resources), it is not encouraged.
"""
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed)
self.observation_space = spaces.Box(
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
)
self.action_space = spaces.Discrete(3)
return self.observation_space.sample() # Dummy observation
if not return_info:
return self.observation_space.sample() # Dummy observation
else:
return self.observation_space.sample(), {} # Dummy observation with info
def step(self, action):
observation = self.observation_space.sample() # Dummy observation

View File

@@ -40,6 +40,32 @@ def test_reset_async_vector_env(shared_memory):
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations = env.reset(return_info=False)
finally:
env.close()
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations, infos = env.reset(return_info=True)
finally:
env.close()
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
assert isinstance(infos, list)
assert all([isinstance(info, dict) for info in infos])
@pytest.mark.parametrize("shared_memory", [True, False])
@pytest.mark.parametrize("use_single_action_space", [True, False])

View File

@@ -31,6 +31,37 @@ def test_reset_sync_vector_env():
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
del observations
try:
env = SyncVectorEnv(env_fns)
observations = env.reset(return_info=False)
finally:
env.close()
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
del observations
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try:
env = SyncVectorEnv(env_fns)
observations, infos = env.reset(return_info=True)
finally:
env.close()
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
assert isinstance(infos, list)
assert all([isinstance(info, dict) for info in infos])
@pytest.mark.parametrize("use_single_action_space", [True, False])
def test_step_sync_vector_env(use_single_action_space):

View File

@@ -23,10 +23,19 @@ class DummyRewardEnv(gym.Env):
self.t += 1
return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
def reset(
self,
*,
seed: Optional[int] = None,
return_info: Optional[bool] = False,
options: Optional[dict] = None
):
super().reset(seed=seed)
self.t = self.return_reward_idx
return np.array([self.t])
if not return_info:
return np.array([self.t])
else:
return np.array([self.t]), {}
def make_env(return_reward_idx):
@@ -47,6 +56,20 @@ def test_normalize_observation():
assert_almost_equal(env.obs_rms.mean, 1.0, decimal=4)
def test_normalize_reset_info():
    """Verify ``NormalizeObservation.reset`` with and without ``return_info``.

    A plain reset and a reset with ``return_info=False`` must each return an
    ``np.ndarray``; a reset with ``return_info=True`` must return an
    ``(np.ndarray, dict)`` pair.
    """
    wrapped = NormalizeObservation(DummyRewardEnv(return_reward_idx=0))

    # The default call and the explicit opt-out share the same expectation.
    for reset_kwargs in ({}, {"return_info": False}):
        observation = wrapped.reset(**reset_kwargs)
        assert isinstance(observation, np.ndarray)

    observation, metadata = wrapped.reset(return_info=True)
    assert isinstance(observation, np.ndarray)
    assert isinstance(metadata, dict)
def test_normalize_return():
env = DummyRewardEnv(return_reward_idx=0)
env = NormalizeReward(env)

View File

@@ -0,0 +1,21 @@
import pytest
import numpy as np
import gym
from gym.wrappers import OrderEnforcing
def test_order_enforcing_reset_info():
    """Verify that ``OrderEnforcing`` forwards ``return_info`` through ``reset``.

    Resets a wrapped ``CartPole-v1`` three ways and checks that observations
    lie in the observation space and that the info value, when requested, is
    a ``dict``.
    """
    wrapped = OrderEnforcing(gym.make("CartPole-v1"))
    space = wrapped.observation_space

    # The default call and the explicit opt-out both return a bare observation.
    for reset_kwargs in ({}, {"return_info": False}):
        observation = wrapped.reset(**reset_kwargs)
        assert space.contains(observation)

    observation, metadata = wrapped.reset(return_info=True)
    assert space.contains(observation)
    assert isinstance(metadata, dict)

View File

@@ -1,5 +1,7 @@
import pytest
import numpy as np
import gym
from gym.wrappers import RecordEpisodeStatistics
@@ -24,6 +26,18 @@ def test_record_episode_statistics(env_id, deque_size):
assert len(env.length_queue) == deque_size
def test_record_episode_statistics_reset_info():
    """Verify ``RecordEpisodeStatistics.reset`` with and without ``return_info``.

    A plain reset of a wrapped ``CartPole-v1`` must return an observation in
    the observation space; a reset with ``return_info=True`` must also return
    a ``dict``.
    """
    wrapped = RecordEpisodeStatistics(gym.make("CartPole-v1"))
    space = wrapped.observation_space

    observation = wrapped.reset()
    assert space.contains(observation)

    observation, metadata = wrapped.reset(return_info=True)
    assert space.contains(observation)
    assert isinstance(metadata, dict)
@pytest.mark.parametrize(
("num_envs", "asynchronous"), [(1, False), (1, True), (4, False), (4, True)]
)

View File

@@ -1,7 +1,9 @@
import pytest
import os
import shutil
import numpy as np
import gym
from gym.wrappers import (
RecordEpisodeStatistics,
@@ -29,6 +31,36 @@ def test_record_video_using_default_trigger():
shutil.rmtree("videos")
def test_record_video_reset_return_info():
    """Verify ``RecordVideo.reset`` with ``return_info`` True, False and unset.

    For each variant a fresh ``CartPole-v1`` is wrapped in ``RecordVideo``
    writing to ``"videos"``, reset once and closed.  The ``videos`` directory
    must exist afterwards, the observation must lie in the observation space,
    and the info value (when requested) must be a ``dict``.
    """
    # (reset keyword arguments, whether an (obs, info) pair is expected)
    scenarios = (
        ({"return_info": True}, True),
        ({"return_info": False}, False),
        ({}, False),
    )

    for reset_kwargs, expects_info in scenarios:
        env = gym.make("CartPole-v1")
        env = gym.wrappers.RecordVideo(
            env, "videos", step_trigger=lambda x: x % 100 == 0
        )
        ob_space = env.observation_space

        try:
            try:
                result = env.reset(**reset_kwargs)
            finally:
                # Always close the env so a failing reset does not leave it open.
                env.close()
            assert os.path.isdir("videos")
        finally:
            # Remove the output directory even on failure so it cannot leak
            # into the next scenario or into other tests that use "videos".
            shutil.rmtree("videos", ignore_errors=True)

        if expects_info:
            obs, info = result
            assert ob_space.contains(obs)
            assert isinstance(info, dict)
        else:
            assert ob_space.contains(result)
def test_record_video_step_trigger():
env = gym.make("CartPole-v1")
env._max_episode_steps = 20

View File

@@ -0,0 +1,21 @@
import pytest
import numpy as np
import gym
from gym.wrappers import TimeLimit
def test_time_limit_reset_info():
    """Verify that ``TimeLimit`` forwards ``return_info`` through ``reset``.

    Resets a wrapped ``CartPole-v1`` three ways and checks that observations
    lie in the observation space and that the info value, when requested, is
    a ``dict``.
    """
    wrapped = TimeLimit(gym.make("CartPole-v1"))
    space = wrapped.observation_space

    # The default call and the explicit opt-out both return a bare observation.
    for reset_kwargs in ({}, {"return_info": False}):
        observation = wrapped.reset(**reset_kwargs)
        assert space.contains(observation)

    observation, metadata = wrapped.reset(return_info=True)
    assert space.contains(observation)
    assert isinstance(metadata, dict)