New info API for vectorized environments #2657 (#2773)

* WIP refactor info API sync vector.

* Add missing untracked file.

* Add info strategy to reset_wait.

* Add interface and docstring.

* info with strategy pattern on async vector env.

* Add default to async vecenv.

* episode statistics for asyncvecenv.

* Add tests info strategy format.

* Add info strategy to reset_wait.

* refactor and cleanup.

* Code cleanup. Add tests.

* Add tests for video recording with new info format.

* fix test case.

* fix camelcase.

* rename enum.

* update tests, docstrings, cleanup.

* Changes brax strategy to numpy. add_strategy method in StrategyFactory. Add tests.

* fix docstring and logging format.

* Set Brax info format as default. Remove classic info format. Update tests.

* breaking the wrong loop.

* WIP: wrapper.

* Add wrapper for brax to classic info.

* WIP: wrapper with nested RecordEpisodeStatistic.

* Add tests. Refactor docstrings. Cleanup.

* cleanup.

* patch conflicts.

* rebase and conflicts.

* new pre-commit conventions.

* docstring.

* renaming.

* incorporate info_processor in vecEnv.

* renaming. Create info dict only if needed.

* remove all brax references. update docstring. Update duplicate test.

* reviews.

* pre-commit.

* reviews.

* docstring.

* cleanup blank lines.

* add support for numpy dtypes.

* docstring fix.

* formatting.

* naming.

* assert correct info from wrappers chaining. Test correct wrappers chaining. naming.

* simplify episode_statistics.

* change args order.

* update tests.

* wip: refactor episode_statistics.

* Add test for add_vector_episode_statistics.
Authored by Gianluca De Cola on 2022-05-24 16:36:35 +02:00; committed by GitHub.
parent bbf8f5a467
commit 49d8299a1e
13 changed files with 428 additions and 42 deletions

gym/vector/async_vector_env.py

@@ -271,8 +271,10 @@ class AsyncVectorEnv(VectorEnv):
         self._state = AsyncState.DEFAULT

         if return_info:
-            results, infos = zip(*results)
-            infos = list(infos)
+            infos = {}
+            results, info_data = zip(*results)
+            for i, info in enumerate(info_data):
+                infos = self._add_info(infos, info, i)

         if not self.shared_memory:
             self.observations = concatenate(

@@ -344,10 +346,20 @@ class AsyncVectorEnv(VectorEnv):
                 f"The call to `step_wait` has timed out after {timeout} second(s)."
             )

-        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
+        observations_list, rewards, dones, infos = [], [], [], {}
+        successes = []
+        for i, pipe in enumerate(self.parent_pipes):
+            result, success = pipe.recv()
+            obs, rew, done, info = result
+            successes.append(success)
+            observations_list.append(obs)
+            rewards.append(rew)
+            dones.append(done)
+            infos = self._add_info(infos, info, i)
+
         self._raise_if_errors(successes)
         self._state = AsyncState.DEFAULT
-        observations_list, rewards, dones, infos = zip(*results)

         if not self.shared_memory:
             self.observations = concatenate(
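The net effect of these two hunks is that `AsyncVectorEnv` now returns one info dict for the whole batch instead of a tuple of per-env dicts. A minimal caller-side sketch (assuming this branch of gym and the standard CartPole-v1 env; the keys present depend on the wrapped envs)::

    import gym

    envs = gym.vector.make("CartPole-v1", num_envs=3, asynchronous=True)
    envs.reset(seed=42)
    for _ in range(200):
        _, _, dones, infos = envs.step(envs.action_space.sample())
        if any(dones):
            # one dict for the whole batch; "_"-prefixed keys are boolean masks
            print(infos["terminal_observation"])   # object array, None where not done
            print(infos["_terminal_observation"])  # e.g. array([ True, False, False])
            break
    envs.close()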

gym/vector/sync_vector_env.py

@@ -108,8 +108,8 @@ class SyncVectorEnv(VectorEnv):
         self._dones[:] = False
         observations = []
-        data_list = []
-        for env, single_seed in zip(self.envs, seed):
+        infos = {}
+        for i, (env, single_seed) in enumerate(zip(self.envs, seed)):

             kwargs = {}
             if single_seed is not None:

@@ -123,9 +123,9 @@ class SyncVectorEnv(VectorEnv):
                 observation = env.reset(**kwargs)
                 observations.append(observation)
             else:
-                observation, data = env.reset(**kwargs)
+                observation, info = env.reset(**kwargs)
                 observations.append(observation)
-                data_list.append(data)
+                infos = self._add_info(infos, info, i)

         self.observations = concatenate(
             self.single_observation_space, observations, self.observations

@@ -135,7 +135,7 @@ class SyncVectorEnv(VectorEnv):
         else:
             return (
                 deepcopy(self.observations) if self.copy else self.observations
-            ), data_list
+            ), infos

     def step_async(self, actions):
         """Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""

@@ -147,14 +147,14 @@ class SyncVectorEnv(VectorEnv):
         Returns:
             The batched environment step results
         """
-        observations, infos = [], []
+        observations, infos = [], {}
         for i, (env, action) in enumerate(zip(self.envs, self._actions)):
             observation, self._rewards[i], self._dones[i], info = env.step(action)
             if self._dones[i]:
                 info["terminal_observation"] = observation
                 observation = env.reset()
             observations.append(observation)
-            infos.append(info)
+            infos = self._add_info(infos, info, i)
         self.observations = concatenate(
             self.single_observation_space, observations, self.observations
         )
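The `SyncVectorEnv` changes mirror the async ones; `reset(return_info=True)` now also hands back a dict. A short sketch under the same assumptions as above::

    import gym

    envs = gym.vector.make("CartPole-v1", num_envs=3, asynchronous=False)
    obs, infos = envs.reset(seed=42, return_info=True)
    assert isinstance(infos, dict)  # previously a list of per-env dicts
    envs.close()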

gym/vector/vector_env.py

@@ -3,6 +3,8 @@ from __future__ import annotations
 from typing import Any, Optional, Union

+import numpy as np
+
 import gym
 from gym.logger import deprecation
 from gym.vector.utils.spaces import batch_space

@@ -201,6 +203,58 @@ class VectorEnv(gym.Env):
             "Please use `env.reset(seed=seed) instead in VectorEnvs."
         )

+    def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
+        """Add env info to the info dictionary of the vectorized environment.
+
+        Given the `info` of a single environment add it to the `infos` dictionary
+        which represents all the infos of the vectorized environment.
+        Every `key` of `info` is paired with a boolean mask `_key` representing
+        whether or not the i-indexed environment has this `info`.
+
+        Args:
+            infos (dict): the infos of the vectorized environment
+            info (dict): the info coming from the single environment
+            env_num (int): the index of the single environment
+
+        Returns:
+            infos (dict): the (updated) infos of the vectorized environment
+        """
+        for k in info.keys():
+            if k not in infos:
+                info_array, array_mask = self._init_info_arrays(type(info[k]))
+            else:
+                info_array, array_mask = infos[k], infos[f"_{k}"]
+
+            info_array[env_num], array_mask[env_num] = info[k], True
+            infos[k], infos[f"_{k}"] = info_array, array_mask
+        return infos
+
+    def _init_info_arrays(self, dtype: type) -> np.ndarray:
+        """Initialize the info array.
+
+        Initialize the info array. If the dtype is numeric
+        the info array will have the same dtype, otherwise
+        will be an array of `None`. Also, a boolean array
+        of the same length is returned. It will be used for
+        assessing which environment has info data.
+
+        Args:
+            dtype (type): data type of the info coming from the env.
+
+        Returns:
+            array (np.ndarray): the initialized info array.
+            array_mask (np.ndarray): the initialized boolean array.
+        """
+        if dtype in [int, float, bool] or issubclass(dtype, np.number):
+            array = np.zeros(self.num_envs, dtype=dtype)
+        else:
+            array = np.zeros(self.num_envs, dtype=object)
+            array[:] = None
+        array_mask = np.zeros(self.num_envs, dtype=bool)
+        return array, array_mask
+
     def __del__(self):
         """Closes the vector environment."""
         if not getattr(self, "closed", True):
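To make the `_key` mask convention concrete, here is a self-contained numpy sketch of the structure `_add_info` builds when only some envs report a key. This is an illustrative re-implementation, not the internal helper itself, and it skips the object-array fallback that `_init_info_arrays` uses for non-numeric values::

    import numpy as np

    num_envs = 3
    per_env_infos = [{"reward_ctrl": 0.5}, {}, {"reward_ctrl": 0.3}]

    infos = {}
    for i, info in enumerate(per_env_infos):
        for k, v in info.items():
            if k not in infos:
                # numeric values get a typed zero array keyed by k,
                # plus a boolean mask keyed by _k
                infos[k] = np.zeros(num_envs, dtype=type(v))
                infos[f"_{k}"] = np.zeros(num_envs, dtype=bool)
            infos[k][i] = v
            infos[f"_{k}"][i] = True

    # infos == {"reward_ctrl": array([0.5, 0. , 0.3]),
    #           "_reward_ctrl": array([ True, False,  True])}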

gym/wrappers/__init__.py

@@ -1,4 +1,5 @@
 """Module of wrapper classes."""
+from gym import error
 from gym.wrappers.atari_preprocessing import AtariPreprocessing
 from gym.wrappers.autoreset import AutoResetWrapper
 from gym.wrappers.clip_action import ClipAction

@@ -8,7 +9,6 @@ from gym.wrappers.frame_stack import FrameStack, LazyFrames
 from gym.wrappers.gray_scale_observation import GrayScaleObservation
 from gym.wrappers.normalize import NormalizeObservation, NormalizeReward
 from gym.wrappers.order_enforcing import OrderEnforcing
-from gym.wrappers.pixel_observation import PixelObservationWrapper
 from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
 from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
 from gym.wrappers.rescale_action import RescaleAction

@@ -17,3 +17,4 @@ from gym.wrappers.time_aware_observation import TimeAwareObservation
 from gym.wrappers.time_limit import TimeLimit
 from gym.wrappers.transform_observation import TransformObservation
 from gym.wrappers.transform_reward import TransformReward
+from gym.wrappers.vector_list_info import VectorListInfo

gym/wrappers/record_episode_statistics.py

@@ -7,11 +7,44 @@ import numpy as np
 import gym


+def add_vector_episode_statistics(
+    info: dict, episode_info: dict, num_envs: int, env_num: int
+):
+    """Add episode statistics.
+
+    Add statistics coming from the vectorized environment.
+
+    Args:
+        info (dict): info dict of the environment.
+        episode_info (dict): episode statistics data.
+        num_envs (int): number of environments.
+        env_num (int): env number of the vectorized environments.
+
+    Returns:
+        info (dict): the input info dict with the episode statistics.
+    """
+    info["episode"] = info.get("episode", {})
+
+    info["_episode"] = info.get("_episode", np.zeros(num_envs, dtype=bool))
+    info["_episode"][env_num] = True
+
+    for k in episode_info.keys():
+        info_array = info["episode"].get(k, np.zeros(num_envs))
+        info_array[env_num] = episode_info[k]
+        info["episode"][k] = info_array
+
+    return info
+
+
 class RecordEpisodeStatistics(gym.Wrapper):
     """This wrapper will keep track of cumulative rewards and episode lengths.

-    At the end of an episode, the statistics of the episode will be added to ``info``. After the completion
-    of an episode, ``info`` will look like this::
+    At the end of an episode, the statistics of the episode will be added to ``info``
+    using the key ``episode``. If using a vectorized environment also the key
+    ``_episode`` is used which indicates whether the env at the respective index has
+    the episode statistics.
+
+    After the completion of an episode, ``info`` will look like this::

     >>> info = {
     ...     ...

@@ -22,6 +55,18 @@ class RecordEpisodeStatistics(gym.Wrapper):
     ...     },
     ... }

+    For vectorized environments the output will be in the form of::
+
+    >>> infos = {
+    ...     ...
+    ...     "episode": {
+    ...         "r": "<array of cumulative reward>",
+    ...         "l": "<array of episode length>",
+    ...         "t": "<array of elapsed time since instantiation of wrapper>"
+    ...     },
+    ...     "_episode": "<boolean array of length num-envs>"
+    ... }
+
     Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
     :attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.

@@ -57,34 +102,40 @@ class RecordEpisodeStatistics(gym.Wrapper):
     def step(self, action):
         """Steps through the environment, recording the episode statistics."""
         observations, rewards, dones, infos = super().step(action)
+        assert isinstance(
+            infos, dict
+        ), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
         self.episode_returns += rewards
         self.episode_lengths += 1
         if not self.is_vector_env:
-            infos = [infos]
             dones = [dones]
-        else:
-            infos = list(infos)  # Convert infos to mutable type
+        dones = list(dones)
+
         for i in range(len(dones)):
             if dones[i]:
-                infos[i] = infos[i].copy()
                 episode_return = self.episode_returns[i]
                 episode_length = self.episode_lengths[i]
                 episode_info = {
-                    "r": episode_return,
-                    "l": episode_length,
-                    "t": round(time.perf_counter() - self.t0, 6),
+                    "episode": {
+                        "r": episode_return,
+                        "l": episode_length,
+                        "t": round(time.perf_counter() - self.t0, 6),
+                    }
                 }
-                infos[i]["episode"] = episode_info
+                if self.is_vector_env:
+                    infos = add_vector_episode_statistics(
+                        infos, episode_info["episode"], self.num_envs, i
+                    )
+                else:
+                    infos = {**infos, **episode_info}
                 self.return_queue.append(episode_return)
                 self.length_queue.append(episode_length)
                 self.episode_count += 1
                 self.episode_returns[i] = 0
                 self.episode_lengths[i] = 0
-        if self.is_vector_env:
-            infos = tuple(infos)
         return (
             observations,
             rewards,
             dones if self.is_vector_env else dones[0],
-            infos if self.is_vector_env else infos[0],
+            infos,
         )
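A usage sketch of the wrapper under the new format (assuming this branch of gym and CartPole-v1)::

    import gym
    from gym.wrappers import RecordEpisodeStatistics

    envs = RecordEpisodeStatistics(gym.vector.make("CartPole-v1", num_envs=3))
    envs.reset(seed=42)
    while True:
        _, _, dones, infos = envs.step(envs.action_space.sample())
        if any(dones):
            # "episode" holds per-env stat arrays, "_episode" masks finished envs
            print(infos["episode"]["r"], infos["_episode"])
            break
    envs.close()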

gym/wrappers/vector_list_info.py

@@ -0,0 +1,114 @@
+"""Wrapper that converts the info format for vec envs into the list format."""
+from typing import List
+
+import gym
+
+
+class VectorListInfo(gym.Wrapper):
+    """Converts infos of vectorized environments from dict to List[dict].
+
+    This wrapper converts the info format of a
+    vector environment from a dictionary to a list of dictionaries.
+    This wrapper is intended to be used around vectorized
+    environments. If using other wrappers that perform
+    operations on info, like `RecordEpisodeStatistics`, this
+    needs to be the outermost wrapper.
+
+    i.e. VectorListInfo(RecordEpisodeStatistics(envs))
+
+    Example::
+
+        >>> # actual
+        >>> {
+        ...     "k": np.array([0., 0., 0.5, 0.3]),
+        ...     "_k": np.array([False, False, True, True])
+        ... }
+        >>> # classic
+        >>> [{}, {}, {"k": 0.5}, {"k": 0.3}]
+    """
+
+    def __init__(self, env):
+        """This wrapper will convert the info into the list format.
+
+        Args:
+            env (Env): The environment to apply the wrapper
+        """
+        assert getattr(
+            env, "is_vector_env", False
+        ), "This wrapper can only be used in vectorized environments."
+        super().__init__(env)
+
+    def step(self, action):
+        """Steps through the environment, convert dict info to list."""
+        observation, reward, done, infos = self.env.step(action)
+        list_info = self._convert_info_to_list(infos)
+        return observation, reward, done, list_info
+
+    def reset(self, **kwargs):
+        """Resets the environment using kwargs."""
+        if not kwargs.get("return_info"):
+            return self.env.reset(**kwargs)
+        obs, infos = self.env.reset(**kwargs)
+        list_info = self._convert_info_to_list(infos)
+        return obs, list_info
+
+    def _convert_info_to_list(self, infos: dict) -> List[dict]:
+        """Convert the dict info to list.
+
+        Convert the dict info of the vectorized environment
+        into a list of dictionaries where the i-th dictionary
+        has the info of the i-th environment.
+
+        Args:
+            infos (dict): info dict coming from the env.
+
+        Returns:
+            list_info (list): converted info.
+        """
+        list_info = [{} for _ in range(self.num_envs)]
+        list_info = self._process_episode_statistics(infos, list_info)
+        for k in infos:
+            if k.startswith("_"):
+                continue
+            for i, has_info in enumerate(infos[f"_{k}"]):
+                if has_info:
+                    list_info[i][k] = infos[k][i]
+        return list_info
+
+    def _process_episode_statistics(self, infos: dict, list_info: list) -> List[dict]:
+        """Process episode statistics.
+
+        The `RecordEpisodeStatistics` wrapper adds extra
+        information to the info. This information is in
+        the form of a dict of dicts. This method processes
+        this information and adds it to the list info.
+        `RecordEpisodeStatistics` info contains the keys
+        "r", "l", "t" which represent "cumulative reward",
+        "episode length", and "elapsed time since instantiation of wrapper".
+
+        Args:
+            infos (dict): infos coming from `RecordEpisodeStatistics`.
+            list_info (list): info of the current vectorized environment.
+
+        Returns:
+            list_info (list): updated info.
+        """
+        episode_statistics = infos.pop("episode", False)
+        if not episode_statistics:
+            return list_info
+
+        episode_statistics_mask = infos.pop("_episode")
+        for i, has_info in enumerate(episode_statistics_mask):
+            if has_info:
+                list_info[i]["episode"] = {}
+                list_info[i]["episode"]["r"] = episode_statistics["r"][i]
+                list_info[i]["episode"]["l"] = episode_statistics["l"][i]
+                list_info[i]["episode"]["t"] = episode_statistics["t"][i]
+
+        return list_info
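A usage sketch for code that still expects the classic list format (same assumptions as above); note the wrapping order required by the docstring::

    import gym
    from gym.wrappers import RecordEpisodeStatistics, VectorListInfo

    envs = gym.vector.make("CartPole-v1", num_envs=3)
    # VectorListInfo must be outermost so it sees the dict-format info
    envs = VectorListInfo(RecordEpisodeStatistics(envs))
    obs, infos = envs.reset(seed=42, return_info=True)
    assert isinstance(infos, list) and len(infos) == 3  # classic per-env dicts
    envs.close()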

tests/vector/test_async_vector_env.py

@@ -63,7 +63,7 @@ def test_reset_async_vector_env(shared_memory):
     assert observations.dtype == env.observation_space.dtype
     assert observations.shape == (8,) + env.single_observation_space.shape
     assert observations.shape == env.observation_space.shape
-    assert isinstance(infos, list)
+    assert isinstance(infos, dict)
     assert all([isinstance(info, dict) for info in infos])

tests/vector/test_sync_vector_env.py

@@ -65,7 +65,7 @@ def test_reset_sync_vector_env():
     assert observations.dtype == env.observation_space.dtype
     assert observations.shape == (8,) + env.single_observation_space.shape
     assert observations.shape == env.observation_space.shape
-    assert isinstance(infos, list)
+    assert isinstance(infos, dict)
     assert all([isinstance(info, dict) for info in infos])

tests/vector/test_vector_env.py

@@ -35,11 +35,11 @@ def test_vector_env_equal(shared_memory):
         sync_observations, sync_rewards, sync_dones, sync_infos = sync_env.step(actions)
         # fmt: on

-        for idx in range(len(sync_dones)):
-            if sync_dones[idx]:
-                assert "terminal_observation" in async_infos[idx]
-                assert "terminal_observation" in sync_infos[idx]
-                assert sync_dones[idx]
+        if any(sync_dones):
+            assert "terminal_observation" in async_infos
+            assert "_terminal_observation" in async_infos
+            assert "terminal_observation" in sync_infos
+            assert "_terminal_observation" in sync_infos

         assert np.all(async_observations == sync_observations)
         assert np.all(async_rewards == sync_rewards)

tests/vector/test_vector_env_info.py

@@ -0,0 +1,54 @@
+import numpy as np
+import pytest
+
+import gym
+from gym.vector.sync_vector_env import SyncVectorEnv
+from tests.vector.utils import make_env
+
+ENV_ID = "CartPole-v1"
+NUM_ENVS = 3
+ENV_STEPS = 50
+SEED = 42
+
+
+@pytest.mark.parametrize("asynchronous", [True, False])
+def test_vector_env_info(asynchronous):
+    env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, asynchronous=asynchronous)
+    env.reset(seed=SEED)
+    for _ in range(ENV_STEPS):
+        env.action_space.seed(SEED)
+        action = env.action_space.sample()
+        _, _, dones, infos = env.step(action)
+        if any(dones):
+            assert len(infos["terminal_observation"]) == NUM_ENVS
+            assert len(infos["_terminal_observation"]) == NUM_ENVS
+
+            assert isinstance(infos["terminal_observation"], np.ndarray)
+            assert isinstance(infos["_terminal_observation"], np.ndarray)
+
+            for i, done in enumerate(dones):
+                if done:
+                    assert infos["_terminal_observation"][i]
+                else:
+                    assert not infos["_terminal_observation"][i]
+                    assert infos["terminal_observation"][i] is None
+
+
+@pytest.mark.parametrize("concurrent_ends", [1, 2, 3])
+def test_vector_env_info_concurrent_termination(concurrent_ends):
+    # envs that need to terminate together will have the same action
+    actions = [0] * concurrent_ends + [1] * (NUM_ENVS - concurrent_ends)
+    envs = [make_env(ENV_ID, SEED) for _ in range(NUM_ENVS)]
+    envs = SyncVectorEnv(envs)
+
+    for _ in range(ENV_STEPS):
+        _, _, dones, infos = envs.step(actions)
+        if any(dones):
+            for i, done in enumerate(dones):
+                if i < concurrent_ends:
+                    assert done
+                    assert infos["_terminal_observation"][i]
+                else:
+                    assert not infos["_terminal_observation"][i]
+                    assert infos["terminal_observation"][i] is None
+            return

tests/wrappers/test_record_episode_statistics.py

@@ -1,7 +1,9 @@
+import numpy as np
 import pytest

 import gym
-from gym.wrappers import RecordEpisodeStatistics
+from gym.wrappers import RecordEpisodeStatistics, VectorListInfo
+from gym.wrappers.record_episode_statistics import add_vector_episode_statistics


 @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])

@@ -50,8 +52,47 @@ def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous):
     envs.reset()
     for _ in range(max_episode_step + 1):
         _, _, dones, infos = envs.step(envs.action_space.sample())
-        for idx, info in enumerate(infos):
-            if dones[idx]:
-                assert "episode" in info
-                assert all([item in info["episode"] for item in ["r", "l", "t"]])
-                break
+        if any(dones):
+            assert "episode" in infos
+            assert "_episode" in infos
+            assert all(infos["_episode"] == dones)
+            assert all([item in infos["episode"] for item in ["r", "l", "t"]])
+            break
+        else:
+            assert "episode" not in infos
+            assert "_episode" not in infos
+
+
+def test_wrong_wrapping_order():
+    envs = gym.vector.make("CartPole-v1", num_envs=3)
+    wrapped_env = RecordEpisodeStatistics(VectorListInfo(envs))
+    wrapped_env.reset()
+
+    with pytest.raises(AssertionError):
+        wrapped_env.step(wrapped_env.action_space.sample())
+
+
+def test_add_vector_episode_statistics():
+    NUM_ENVS = 5
+    info = {}
+
+    for i in range(NUM_ENVS):
+        episode_info = {
+            "episode": {
+                "r": i,
+                "l": i,
+                "t": i,
+            }
+        }
+
+        info = add_vector_episode_statistics(info, episode_info["episode"], NUM_ENVS, i)
+
+        assert np.alltrue(info["_episode"][: i + 1])
+
+        for j in range(NUM_ENVS):
+            if j <= i:
+                assert info["episode"]["r"][j] == j
+                assert info["episode"]["l"][j] == j
+                assert info["episode"]["t"][j] == j
+            else:
+                assert info["episode"]["r"][j] == 0
+                assert info["episode"]["l"][j] == 0
+                assert info["episode"]["t"][j] == 0

tests/wrappers/test_record_video.py

@@ -90,10 +90,11 @@ def test_record_video_within_vector():
     envs.reset()
     for i in range(199):
         _, _, _, infos = envs.step(envs.action_space.sample())
-        for info in infos:
-            if "episode" in info.keys():
-                print(f"episode_reward={info['episode']['r']}")
-                break
+
+        # break when every env is done
+        if "episode" in infos and all(infos["_episode"]):
+            print(f"episode_reward={infos['episode']['r']}")
+            break
     assert os.path.isdir("videos")
     mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
     assert len(mp4_files) == 2

tests/wrappers/test_vector_list_info.py

@@ -0,0 +1,58 @@
+import pytest
+
+import gym
+from gym.wrappers import RecordEpisodeStatistics, VectorListInfo
+
+ENV_ID = "CartPole-v1"
+NUM_ENVS = 3
+ENV_STEPS = 50
+SEED = 42
+
+
+def test_usage_in_vector_env():
+    env = gym.make(ENV_ID)
+    vector_env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
+
+    VectorListInfo(vector_env)
+
+    with pytest.raises(AssertionError):
+        VectorListInfo(env)
+
+
+def test_info_to_list():
+    env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
+    wrapped_env = VectorListInfo(env_to_wrap)
+    wrapped_env.action_space.seed(SEED)
+    _, info = wrapped_env.reset(seed=SEED, return_info=True)
+    assert isinstance(info, list)
+    assert len(info) == NUM_ENVS
+
+    for _ in range(ENV_STEPS):
+        action = wrapped_env.action_space.sample()
+        _, _, dones, list_info = wrapped_env.step(action)
+        for i, done in enumerate(dones):
+            if done:
+                assert "terminal_observation" in list_info[i]
+            else:
+                assert "terminal_observation" not in list_info[i]
+
+
+def test_info_to_list_statistics():
+    env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
+    wrapped_env = VectorListInfo(RecordEpisodeStatistics(env_to_wrap))
+    _, info = wrapped_env.reset(seed=SEED, return_info=True)
+    wrapped_env.action_space.seed(SEED)
+    assert isinstance(info, list)
+    assert len(info) == NUM_ENVS
+
+    for _ in range(ENV_STEPS):
+        action = wrapped_env.action_space.sample()
+        _, _, dones, list_info = wrapped_env.step(action)
+        for i, done in enumerate(dones):
+            if done:
+                assert "episode" in list_info[i]
+                for stats in ["r", "l", "t"]:
+                    assert stats in list_info[i]["episode"]
+                    assert isinstance(list_info[i]["episode"][stats], float)
+            else:
+                assert "episode" not in list_info[i]