Adding return_info argument to reset to allow for optional info dict as a second return value (#2546)

* initial draft of optional info dict in reset function, implemented for cartpole, tests seem to be passing

* merged core.py

* updated return type annotation for reset function in core.py

* optional metadata with return_info from reset added for all first party environments, with corresponding tests. Incomplete implementation for wrappers and vector wrappers

* removed Optional type for return_info arguments

* added tests for return_info to normalize wrapper and sync_vector_env

* autoformatted using black

* added optional reset metadata tests to several wrappers

* added return_info capability to async_vector_env.py and test to verify functionality

* added optional return_info test for record_video.py

* removed tests for mujoco environments

* autoformatted

* improved test coverage for optional reset return_info

* re-removed unit test envs accidentally reintroduced in merge

* removed unnecessary import

* changes based on code-review

* small fix to core wrapper typing and autoformatted record_epsisode_stats

* small change to pass flake8 style
This commit is contained in:
John Balis
2022-02-06 17:28:27 -06:00
committed by GitHub
parent 62e52727d5
commit 15049e22d7
27 changed files with 441 additions and 62 deletions

View File

@@ -1,8 +1,7 @@
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
from typing import TypeVar, Generic, Tuple, SupportsFloat from typing import TypeVar, Generic, Tuple, Union, Optional, SupportsFloat
from typing import Optional
import gym import gym
from gym import error, spaces from gym import error, spaces
@@ -71,8 +70,12 @@ class Env(Generic[ObsType, ActType]):
@abstractmethod @abstractmethod
def reset( def reset(
self, *, seed: Optional[int] = None, options: Optional[dict] = None self,
) -> ObsType: *,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
) -> Union[ObsType, tuple[ObsType, dict]]:
"""Resets the environment to an initial state and returns an initial """Resets the environment to an initial state and returns an initial
observation. observation.
@@ -84,6 +87,7 @@ class Env(Generic[ObsType, ActType]):
Returns: Returns:
observation (object): the initial observation. observation (object): the initial observation.
info (optional dictionary): a dictionary containing extra information, this is only returned if return_info is set to true
""" """
# Initialize the RNG if it's the first reset, or if the seed is manually passed # Initialize the RNG if it's the first reset, or if the seed is manually passed
if seed is not None or self.np_random is None: if seed is not None or self.np_random is None:
@@ -262,7 +266,7 @@ class Wrapper(Env[ObsType, ActType]):
def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
return self.env.step(action) return self.env.step(action)
def reset(self, **kwargs) -> ObsType: def reset(self, **kwargs) -> Union[ObsType, tuple[ObsType, dict]]:
return self.env.reset(**kwargs) return self.env.reset(**kwargs)
def render(self, mode="human", **kwargs): def render(self, mode="human", **kwargs):

View File

@@ -350,7 +350,13 @@ class BipedalWalker(gym.Env, EzPickle):
x2 = max(p[0] for p in poly) x2 = max(p[0] for p in poly)
self.cloud_poly.append((poly, x1, x2)) self.cloud_poly.append((poly, x1, x2))
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed) super().reset(seed=seed)
self._destroy() self._destroy()
self.world.contactListener_bug_workaround = ContactDetector(self) self.world.contactListener_bug_workaround = ContactDetector(self)
@@ -436,8 +442,10 @@ class BipedalWalker(gym.Env, EzPickle):
return fraction return fraction
self.lidar = [LidarCallback() for _ in range(10)] self.lidar = [LidarCallback() for _ in range(10)]
if not return_info:
return self.step(np.array([0, 0, 0, 0]))[0] return self.step(np.array([0, 0, 0, 0]))[0]
else:
return self.step(np.array([0, 0, 0, 0]))[0], {}
def step(self, action): def step(self, action):
# self.hull.ApplyForceToCenter((0, 20), True) -- Uncomment this to receive a bit of stability help # self.hull.ApplyForceToCenter((0, 20), True) -- Uncomment this to receive a bit of stability help

View File

@@ -374,7 +374,13 @@ class CarRacing(gym.Env, EzPickle):
self.track = track self.track = track
return True return True
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed) super().reset(seed=seed)
self._destroy() self._destroy()
self.reward = 0.0 self.reward = 0.0
@@ -395,7 +401,10 @@ class CarRacing(gym.Env, EzPickle):
) )
self.car = Car(self.world, *self.track[0][1:4]) self.car = Car(self.world, *self.track[0][1:4])
if not return_info:
return self.step(None)[0] return self.step(None)[0]
else:
return self.step(None)[0], {}
def step(self, action): def step(self, action):
if action is not None: if action is not None:

View File

@@ -183,7 +183,13 @@ class LunarLander(gym.Env, EzPickle):
self.world.DestroyBody(self.legs[0]) self.world.DestroyBody(self.legs[0])
self.world.DestroyBody(self.legs[1]) self.world.DestroyBody(self.legs[1])
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed) super().reset(seed=seed)
self._destroy() self._destroy()
self.world.contactListener_keepref = ContactDetector(self) self.world.contactListener_keepref = ContactDetector(self)
@@ -288,7 +294,10 @@ class LunarLander(gym.Env, EzPickle):
self.drawlist = [self.lander] + self.legs self.drawlist = [self.lander] + self.legs
if not return_info:
return self.step(np.array([0, 0]) if self.continuous else 0)[0] return self.step(np.array([0, 0]) if self.continuous else 0)[0]
else:
return self.step(np.array([0, 0]) if self.continuous else 0)[0], {}
def _create_particle(self, mass, x, y, ttl): def _create_particle(self, mass, x, y, ttl):
p = self.world.CreateDynamicBody( p = self.world.CreateDynamicBody(

View File

@@ -162,12 +162,21 @@ class AcrobotEnv(core.Env):
self.action_space = spaces.Discrete(3) self.action_space = spaces.Discrete(3)
self.state = None self.state = None
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed) super().reset(seed=seed)
self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)).astype( self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)).astype(
np.float32 np.float32
) )
if not return_info:
return self._get_ob() return self._get_ob()
else:
return self._get_ob(), {}
def step(self, a): def step(self, a):
s = self.state s = self.state

View File

@@ -165,11 +165,20 @@ class CartPoleEnv(gym.Env):
return np.array(self.state, dtype=np.float32), reward, done, {} return np.array(self.state, dtype=np.float32), reward, done, {}
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed) super().reset(seed=seed)
self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,)) self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
self.steps_beyond_done = None self.steps_beyond_done = None
if not return_info:
return np.array(self.state, dtype=np.float32) return np.array(self.state, dtype=np.float32)
else:
return np.array(self.state, dtype=np.float32), {}
def render(self, mode="human"): def render(self, mode="human"):
screen_width = 600 screen_width = 600

View File

@@ -130,10 +130,19 @@ class Continuous_MountainCarEnv(gym.Env):
self.state = np.array([position, velocity], dtype=np.float32) self.state = np.array([position, velocity], dtype=np.float32)
return self.state, reward, done, {} return self.state, reward, done, {}
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed) super().reset(seed=seed)
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0]) self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
if not return_info:
return np.array(self.state, dtype=np.float32) return np.array(self.state, dtype=np.float32)
else:
return np.array(self.state, dtype=np.float32), {}
def _height(self, xs): def _height(self, xs):
return np.sin(3 * xs) * 0.45 + 0.55 return np.sin(3 * xs) * 0.45 + 0.55

View File

@@ -93,10 +93,19 @@ class MountainCarEnv(gym.Env):
self.state = (position, velocity) self.state = (position, velocity)
return np.array(self.state, dtype=np.float32), reward, done, {} return np.array(self.state, dtype=np.float32), reward, done, {}
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed) super().reset(seed=seed)
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0]) self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
if not return_info:
return np.array(self.state, dtype=np.float32) return np.array(self.state, dtype=np.float32)
else:
return np.array(self.state, dtype=np.float32), {}
def _height(self, xs): def _height(self, xs):
return np.sin(3 * xs) * 0.45 + 0.55 return np.sin(3 * xs) * 0.45 + 0.55

View File

@@ -109,12 +109,21 @@ class PendulumEnv(gym.Env):
self.state = np.array([newth, newthdot]) self.state = np.array([newth, newthdot])
return self._get_obs(), -costs, False, {} return self._get_obs(), -costs, False, {}
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed) super().reset(seed=seed)
high = np.array([np.pi, 1]) high = np.array([np.pi, 1])
self.state = self.np_random.uniform(low=-high, high=high) self.state = self.np_random.uniform(low=-high, high=high)
self.last_u = None self.last_u = None
if not return_info:
return self._get_obs() return self._get_obs()
else:
return self._get_obs(), {}
def _get_obs(self): def _get_obs(self):
theta, thetadot = self.state theta, thetadot = self.state

View File

@@ -103,11 +103,20 @@ class MujocoEnv(gym.Env):
# ----------------------------- # -----------------------------
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed) super().reset(seed=seed)
self.sim.reset() self.sim.reset()
ob = self.reset_model() ob = self.reset_model()
if not return_info:
return ob return ob
else:
return ob, {}
def set_state(self, qpos, qvel): def set_state(self, qpos, qvel):
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,) assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)

View File

@@ -153,11 +153,19 @@ class BlackjackEnv(gym.Env):
def _get_obs(self): def _get_obs(self):
return (sum_hand(self.player), self.dealer[0], usable_ace(self.player)) return (sum_hand(self.player), self.dealer[0], usable_ace(self.player))
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(
self,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed) super().reset(seed=seed)
self.dealer = draw_hand(self.np_random) self.dealer = draw_hand(self.np_random)
self.player = draw_hand(self.np_random) self.player = draw_hand(self.np_random)
if not return_info:
return self._get_obs() return self._get_obs()
else:
return self._get_obs(), {}
def render(self, mode="human"): def render(self, mode="human"):
player_sum, dealer_card_value, usable_ace = self._get_obs() player_sum, dealer_card_value, usable_ace = self._get_obs()

View File

@@ -102,11 +102,20 @@ class CliffWalkingEnv(Env):
self.lastaction = a self.lastaction = a
return (int(s), r, d, {"prob": p}) return (int(s), r, d, {"prob": p})
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed) super().reset(seed=seed)
self.s = categorical_sample(self.initial_state_distrib, self.np_random) self.s = categorical_sample(self.initial_state_distrib, self.np_random)
self.lastaction = None self.lastaction = None
if not return_info:
return int(self.s) return int(self.s)
else:
return int(self.s), {"prob": 1}
def render(self, mode="human"): def render(self, mode="human"):
outfile = StringIO() if mode == "ansi" else sys.stdout outfile = StringIO() if mode == "ansi" else sys.stdout

View File

@@ -212,11 +212,21 @@ class FrozenLakeEnv(Env):
self.lastaction = a self.lastaction = a
return (int(s), r, d, {"prob": p}) return (int(s), r, d, {"prob": p})
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed) super().reset(seed=seed)
self.s = categorical_sample(self.initial_state_distrib, self.np_random) self.s = categorical_sample(self.initial_state_distrib, self.np_random)
self.lastaction = None self.lastaction = None
if not return_info:
return int(self.s) return int(self.s)
else:
return int(self.s), {"prob": 1}
def render(self, mode="human"): def render(self, mode="human"):
desc = self.desc.tolist() desc = self.desc.tolist()

View File

@@ -212,11 +212,20 @@ class TaxiEnv(Env):
self.lastaction = a self.lastaction = a
return (int(s), r, d, {"prob": p}) return (int(s), r, d, {"prob": p})
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
super().reset(seed=seed) super().reset(seed=seed)
self.s = categorical_sample(self.initial_state_distrib, self.np_random) self.s = categorical_sample(self.initial_state_distrib, self.np_random)
self.lastaction = None self.lastaction = None
if not return_info:
return int(self.s) return int(self.s)
else:
return int(self.s), {"prob": 1}
def render(self, mode="human"): def render(self, mode="human"):
outfile = StringIO() if mode == "ansi" else sys.stdout outfile = StringIO() if mode == "ansi" else sys.stdout

View File

@@ -215,6 +215,7 @@ class AsyncVectorEnv(VectorEnv):
def reset_async( def reset_async(
self, self,
seed: Optional[Union[int, List[int]]] = None, seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
"""Send the calls to :obj:`reset` to each sub-environment. """Send the calls to :obj:`reset` to each sub-environment.
@@ -248,6 +249,8 @@ class AsyncVectorEnv(VectorEnv):
single_kwargs = {} single_kwargs = {}
if single_seed is not None: if single_seed is not None:
single_kwargs["seed"] = single_seed single_kwargs["seed"] = single_seed
if return_info:
single_kwargs["return_info"] = return_info
if options is not None: if options is not None:
single_kwargs["options"] = options single_kwargs["options"] = options
@@ -255,7 +258,11 @@ class AsyncVectorEnv(VectorEnv):
self._state = AsyncState.WAITING_RESET self._state = AsyncState.WAITING_RESET
def reset_wait( def reset_wait(
self, timeout=None, seed: Optional[int] = None, options: Optional[dict] = None self,
timeout=None,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
): ):
""" """
Parameters Parameters
@@ -270,6 +277,7 @@ class AsyncVectorEnv(VectorEnv):
------- -------
element of :attr:`~VectorEnv.observation_space` element of :attr:`~VectorEnv.observation_space`
A batch of observations from the vectorized environment. A batch of observations from the vectorized environment.
infos : list of dicts containing metadata
Raises Raises
------ ------
@@ -300,6 +308,19 @@ class AsyncVectorEnv(VectorEnv):
self._raise_if_errors(successes) self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT self._state = AsyncState.DEFAULT
if return_info:
results, infos = zip(*results)
infos = list(infos)
if not self.shared_memory:
self.observations = concatenate(
self.single_observation_space, results, self.observations
)
return (
deepcopy(self.observations) if self.copy else self.observations
), infos
else:
if not self.shared_memory: if not self.shared_memory:
self.observations = concatenate( self.observations = concatenate(
self.single_observation_space, results, self.observations self.single_observation_space, results, self.observations
@@ -618,8 +639,13 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
while True: while True:
command, data = pipe.recv() command, data = pipe.recv()
if command == "reset": if command == "reset":
if "return_info" in data and data["return_info"] == True:
observation, info = env.reset(**data)
pipe.send(((observation, info), True))
else:
observation = env.reset(**data) observation = env.reset(**data)
pipe.send((observation, True)) pipe.send((observation, True))
elif command == "step": elif command == "step":
observation, reward, done, info = env.step(data) observation, reward, done, info = env.step(data)
if done: if done:
@@ -677,6 +703,13 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
while True: while True:
command, data = pipe.recv() command, data = pipe.recv()
if command == "reset": if command == "reset":
if "return_info" in data and data["return_info"] == True:
observation, info = env.reset(**data)
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
pipe.send(((None, info), True))
else:
observation = env.reset(**data) observation = env.reset(**data)
write_to_shared_memory( write_to_shared_memory(
observation_space, index, observation, shared_memory observation_space, index, observation, shared_memory

View File

@@ -89,6 +89,7 @@ class SyncVectorEnv(VectorEnv):
def reset_wait( def reset_wait(
self, self,
seed: Optional[Union[int, List[int]]] = None, seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
if seed is None: if seed is None:
@@ -99,19 +100,34 @@ class SyncVectorEnv(VectorEnv):
self._dones[:] = False self._dones[:] = False
observations = [] observations = []
data_list = []
for env, single_seed in zip(self.envs, seed): for env, single_seed in zip(self.envs, seed):
single_kwargs = {}
kwargs = {}
if single_seed is not None: if single_seed is not None:
single_kwargs["seed"] = single_seed kwargs["seed"] = single_seed
if options is not None: if options is not None:
single_kwargs["options"] = options kwargs["options"] = options
observation = env.reset(**single_kwargs) if return_info == True:
kwargs["return_info"] = return_info
if not return_info:
observation = env.reset(**kwargs)
observations.append(observation) observations.append(observation)
else:
observation, data = env.reset(**kwargs)
observations.append(observation)
data_list.append(data)
self.observations = concatenate( self.observations = concatenate(
self.single_observation_space, observations, self.observations self.single_observation_space, observations, self.observations
) )
if not return_info:
return deepcopy(self.observations) if self.copy else self.observations return deepcopy(self.observations) if self.copy else self.observations
else:
return (
deepcopy(self.observations) if self.copy else self.observations
), data_list
def step_async(self, actions): def step_async(self, actions):
self._actions = iterate(self.action_space, actions) self._actions = iterate(self.action_space, actions)

View File

@@ -49,6 +49,7 @@ class VectorEnv(gym.Env):
def reset_async( def reset_async(
self, self,
seed: Optional[Union[int, List[int]]] = None, seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
pass pass
@@ -56,6 +57,7 @@ class VectorEnv(gym.Env):
def reset_wait( def reset_wait(
self, self,
seed: Optional[Union[int, List[int]]] = None, seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
raise NotImplementedError() raise NotImplementedError()
@@ -64,6 +66,7 @@ class VectorEnv(gym.Env):
self, self,
*, *,
seed: Optional[Union[int, List[int]]] = None, seed: Optional[Union[int, List[int]]] = None,
return_info: bool = False,
options: Optional[dict] = None, options: Optional[dict] = None,
): ):
r"""Reset all sub-environments and return a batch of initial observations. r"""Reset all sub-environments and return a batch of initial observations.
@@ -73,8 +76,8 @@ class VectorEnv(gym.Env):
element of :attr:`observation_space` element of :attr:`observation_space`
A batch of observations from the vectorized environment. A batch of observations from the vectorized environment.
""" """
self.reset_async(seed=seed, options=options) self.reset_async(seed=seed, return_info=return_info, options=options)
return self.reset_wait(seed=seed, options=options) return self.reset_wait(seed=seed, return_info=return_info, options=options)
def step_async(self, actions): def step_async(self, actions):
pass pass

View File

@@ -63,13 +63,26 @@ class NormalizeObservation(gym.core.Wrapper):
obs = self.normalize(np.array([obs]))[0] obs = self.normalize(np.array([obs]))[0]
return obs, rews, dones, infos return obs, rews, dones, infos
def reset(self, **kwargs): def reset(
obs = self.env.reset(**kwargs) self,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
):
obs = None
info = None
if not return_info:
obs = self.env.reset(seed=seed, options=options)
else:
obs, info = self.env.reset(seed=seed, return_info=True, options=options)
if self.is_vector_env: if self.is_vector_env:
obs = self.normalize(obs) obs = self.normalize(obs)
else: else:
obs = self.normalize(np.array([obs]))[0] obs = self.normalize(np.array([obs]))[0]
if not return_info:
return obs return obs
else:
return obs, info
def normalize(self, obs): def normalize(self, obs):
self.obs_rms.update(obs) self.obs_rms.update(obs)

View File

@@ -55,6 +55,23 @@ def test_env(spec):
env.close() env.close()
@pytest.mark.parametrize("spec", spec_list)
def test_reset_info(spec):
with pytest.warns(None) as warnings:
env = spec.make()
ob_space = env.observation_space
obs = env.reset()
assert ob_space.contains(obs)
obs = env.reset(return_info=False)
assert ob_space.contains(obs)
obs, info = env.reset(return_info=True)
assert ob_space.contains(obs)
assert isinstance(info, dict)
env.close()
# Run a longer rollout on some environments # Run a longer rollout on some environments
def test_random_rollout(): def test_random_rollout():
for env in [envs.make("CartPole-v0"), envs.make("FrozenLake-v1")]: for env in [envs.make("CartPole-v0"), envs.make("FrozenLake-v1")]:

View File

@@ -35,13 +35,22 @@ class UnknownSpacesEnv(core.Env):
on external resources), it is not encouraged. on external resources), it is not encouraged.
""" """
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None
):
super().reset(seed=seed) super().reset(seed=seed)
self.observation_space = spaces.Box( self.observation_space = spaces.Box(
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8 low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
) )
self.action_space = spaces.Discrete(3) self.action_space = spaces.Discrete(3)
if not return_info:
return self.observation_space.sample() # Dummy observation return self.observation_space.sample() # Dummy observation
else:
return self.observation_space.sample(), {} # Dummy observation with info
def step(self, action): def step(self, action):
observation = self.observation_space.sample() # Dummy observation observation = self.observation_space.sample() # Dummy observation

View File

@@ -40,6 +40,32 @@ def test_reset_async_vector_env(shared_memory):
assert observations.shape == (8,) + env.single_observation_space.shape assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape assert observations.shape == env.observation_space.shape
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations = env.reset(return_info=False)
finally:
env.close()
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations, infos = env.reset(return_info=True)
finally:
env.close()
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
assert isinstance(infos, list)
assert all([isinstance(info, dict) for info in infos])
@pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("shared_memory", [True, False])
@pytest.mark.parametrize("use_single_action_space", [True, False]) @pytest.mark.parametrize("use_single_action_space", [True, False])

View File

@@ -31,6 +31,37 @@ def test_reset_sync_vector_env():
assert observations.shape == (8,) + env.single_observation_space.shape assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape assert observations.shape == env.observation_space.shape
del observations
try:
env = SyncVectorEnv(env_fns)
observations = env.reset(return_info=False)
finally:
env.close()
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
del observations
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try:
env = SyncVectorEnv(env_fns)
observations, infos = env.reset(return_info=True)
finally:
env.close()
assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray)
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
assert isinstance(infos, list)
assert all([isinstance(info, dict) for info in infos])
@pytest.mark.parametrize("use_single_action_space", [True, False]) @pytest.mark.parametrize("use_single_action_space", [True, False])
def test_step_sync_vector_env(use_single_action_space): def test_step_sync_vector_env(use_single_action_space):

View File

@@ -23,10 +23,19 @@ class DummyRewardEnv(gym.Env):
self.t += 1 self.t += 1
return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {} return np.array([self.t]), self.t, self.t == len(self.returned_rewards), {}
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(
self,
*,
seed: Optional[int] = None,
return_info: Optional[bool] = False,
options: Optional[dict] = None
):
super().reset(seed=seed) super().reset(seed=seed)
self.t = self.return_reward_idx self.t = self.return_reward_idx
if not return_info:
return np.array([self.t]) return np.array([self.t])
else:
return np.array([self.t]), {}
def make_env(return_reward_idx): def make_env(return_reward_idx):
@@ -47,6 +56,20 @@ def test_normalize_observation():
assert_almost_equal(env.obs_rms.mean, 1.0, decimal=4) assert_almost_equal(env.obs_rms.mean, 1.0, decimal=4)
def test_normalize_reset_info():
env = DummyRewardEnv(return_reward_idx=0)
env = NormalizeObservation(env)
obs = env.reset()
assert isinstance(obs, np.ndarray)
del obs
obs = env.reset(return_info=False)
assert isinstance(obs, np.ndarray)
del obs
obs, info = env.reset(return_info=True)
assert isinstance(obs, np.ndarray)
assert isinstance(info, dict)
def test_normalize_return(): def test_normalize_return():
env = DummyRewardEnv(return_reward_idx=0) env = DummyRewardEnv(return_reward_idx=0)
env = NormalizeReward(env) env = NormalizeReward(env)

View File

@@ -0,0 +1,21 @@
import pytest
import numpy as np
import gym
from gym.wrappers import OrderEnforcing
def test_order_enforcing_reset_info():
env = gym.make("CartPole-v1")
env = OrderEnforcing(env)
ob_space = env.observation_space
obs = env.reset()
assert ob_space.contains(obs)
del obs
obs = env.reset(return_info=False)
assert ob_space.contains(obs)
del obs
obs, info = env.reset(return_info=True)
assert ob_space.contains(obs)
assert isinstance(info, dict)

View File

@@ -1,5 +1,7 @@
import pytest import pytest
import numpy as np
import gym import gym
from gym.wrappers import RecordEpisodeStatistics from gym.wrappers import RecordEpisodeStatistics
@@ -24,6 +26,18 @@ def test_record_episode_statistics(env_id, deque_size):
assert len(env.length_queue) == deque_size assert len(env.length_queue) == deque_size
def test_record_episode_statistics_reset_info():
env = gym.make("CartPole-v1")
env = RecordEpisodeStatistics(env)
ob_space = env.observation_space
obs = env.reset()
assert ob_space.contains(obs)
del obs
obs, info = env.reset(return_info=True)
assert ob_space.contains(obs)
assert isinstance(info, dict)
@pytest.mark.parametrize( @pytest.mark.parametrize(
("num_envs", "asynchronous"), [(1, False), (1, True), (4, False), (4, True)] ("num_envs", "asynchronous"), [(1, False), (1, True), (4, False), (4, True)]
) )

View File

@@ -1,7 +1,9 @@
import pytest import pytest
import os import os
import shutil import shutil
import numpy as np
import gym import gym
from gym.wrappers import ( from gym.wrappers import (
RecordEpisodeStatistics, RecordEpisodeStatistics,
@@ -29,6 +31,36 @@ def test_record_video_using_default_trigger():
shutil.rmtree("videos") shutil.rmtree("videos")
def test_record_video_reset_return_info():
env = gym.make("CartPole-v1")
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
ob_space = env.observation_space
obs, info = env.reset(return_info=True)
env.close()
assert os.path.isdir("videos")
shutil.rmtree("videos")
assert ob_space.contains(obs)
assert isinstance(info, dict)
env = gym.make("CartPole-v1")
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
ob_space = env.observation_space
obs = env.reset(return_info=False)
env.close()
assert os.path.isdir("videos")
shutil.rmtree("videos")
assert ob_space.contains(obs)
env = gym.make("CartPole-v1")
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
ob_space = env.observation_space
obs = env.reset()
env.close()
assert os.path.isdir("videos")
shutil.rmtree("videos")
assert ob_space.contains(obs)
def test_record_video_step_trigger(): def test_record_video_step_trigger():
env = gym.make("CartPole-v1") env = gym.make("CartPole-v1")
env._max_episode_steps = 20 env._max_episode_steps = 20

View File

@@ -0,0 +1,21 @@
import pytest
import numpy as np
import gym
from gym.wrappers import TimeLimit
def test_time_limit_reset_info():
env = gym.make("CartPole-v1")
env = TimeLimit(env)
ob_space = env.observation_space
obs = env.reset()
assert ob_space.contains(obs)
del obs
obs = env.reset(return_info=False)
assert ob_space.contains(obs)
del obs
obs, info = env.reset(return_info=True)
assert ob_space.contains(obs)
assert isinstance(info, dict)