New info API for vectorized environments #2657 (#2773)

* WIP refactor info API sync vector.

* Add missing untracked file.

* Add info strategy to reset_wait.

* Add interface and docstring.

* info with strategy pattern on async vector env.

* Add default to async vecenv.

* episode statistics for asyncvecnev.

* Add tests info strategy format.

* Add info strategy to reset_wait.

* refactor and cleanup.

* Code cleanup. Add tests.

* Add tests for video recording with new info format.

* fix test case.

* fix camelcase.

* rename enum.

* update tests, docstrings, cleanup.

* Changes brax strategy to numpy. add_strategy method in StrategyFactory. Add tests.

* fix docstring and logging format.

* Set Brax info format as default. Remove classic info format. Update tests.

* breaking the wrong loop.

* WIP: wrapper.

* Add wrapper for brax to classic info.

* WIP: wrapper with nested RecordEpisodeStatistic.

* Add tests. Refactor docstrings. Cleanup.

* cleanup.

* patch conflicts.

* rebase and conflicts.

* new pre-commit conventions.

* docstring.

* renaming.

* incorporate info_processor in vecEnv.

* renaming. Create info dict only if needed.

* remove all brax references. update docstring. Update duplicate test.

* reviews.

* pre-commit.

* reviews.

* docstring.

* cleanup blank lines.

* add support for numpy dtypes.

* docstring fix.

* formatting.

* naming.

* assert correct info from wrappers chaining. Test correct wrappers chaining. naming.

* simplify episode_statistics.

* change args orer.

* update tests.

* wip: refactor episode_statistics.

* Add test for add_vecore_episode_statistics.
This commit is contained in:
Gianluca De Cola
2022-05-24 16:36:35 +02:00
committed by GitHub
parent bbf8f5a467
commit 49d8299a1e
13 changed files with 428 additions and 42 deletions

View File

@@ -271,8 +271,10 @@ class AsyncVectorEnv(VectorEnv):
self._state = AsyncState.DEFAULT
if return_info:
results, infos = zip(*results)
infos = list(infos)
infos = {}
results, info_data = zip(*results)
for i, info in enumerate(info_data):
infos = self._add_info(infos, info, i)
if not self.shared_memory:
self.observations = concatenate(
@@ -344,10 +346,20 @@ class AsyncVectorEnv(VectorEnv):
f"The call to `step_wait` has timed out after {timeout} second(s)."
)
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
observations_list, rewards, dones, infos = [], [], [], {}
successes = []
for i, pipe in enumerate(self.parent_pipes):
result, success = pipe.recv()
obs, rew, done, info = result
successes.append(success)
observations_list.append(obs)
rewards.append(rew)
dones.append(done)
infos = self._add_info(infos, info, i)
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
observations_list, rewards, dones, infos = zip(*results)
if not self.shared_memory:
self.observations = concatenate(

View File

@@ -108,8 +108,8 @@ class SyncVectorEnv(VectorEnv):
self._dones[:] = False
observations = []
data_list = []
for env, single_seed in zip(self.envs, seed):
infos = {}
for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
kwargs = {}
if single_seed is not None:
@@ -123,9 +123,9 @@ class SyncVectorEnv(VectorEnv):
observation = env.reset(**kwargs)
observations.append(observation)
else:
observation, data = env.reset(**kwargs)
observation, info = env.reset(**kwargs)
observations.append(observation)
data_list.append(data)
infos = self._add_info(infos, info, i)
self.observations = concatenate(
self.single_observation_space, observations, self.observations
@@ -135,7 +135,7 @@ class SyncVectorEnv(VectorEnv):
else:
return (
deepcopy(self.observations) if self.copy else self.observations
), data_list
), infos
def step_async(self, actions):
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
@@ -147,14 +147,14 @@ class SyncVectorEnv(VectorEnv):
Returns:
The batched environment step results
"""
observations, infos = [], []
observations, infos = [], {}
for i, (env, action) in enumerate(zip(self.envs, self._actions)):
observation, self._rewards[i], self._dones[i], info = env.step(action)
if self._dones[i]:
info["terminal_observation"] = observation
observation = env.reset()
observations.append(observation)
infos.append(info)
infos = self._add_info(infos, info, i)
self.observations = concatenate(
self.single_observation_space, observations, self.observations
)

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
from typing import Any, Optional, Union
import numpy as np
import gym
from gym.logger import deprecation
from gym.vector.utils.spaces import batch_space
@@ -201,6 +203,58 @@ class VectorEnv(gym.Env):
"Please use `env.reset(seed=seed) instead in VectorEnvs."
)
def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
"""Add env info to the info dictionary of the vectorized environment.
Given the `info` of a single environment add it to the `infos` dictionary
which represents all the infos of the vectorized environment.
Every `key` of `info` is paired with a boolean mask `_key` representing
whether or not the i-indexed environment has this `info`.
Args:
infos (dict): the infos of the vectorized environment
info (dict): the info coming from the single environment
env_num (int): the index of the single environment
Returns:
infos (dict): the (updated) infos of the vectorized environment
"""
for k in info.keys():
if k not in infos:
info_array, array_mask = self._init_info_arrays(type(info[k]))
else:
info_array, array_mask = infos[k], infos[f"_{k}"]
info_array[env_num], array_mask[env_num] = info[k], True
infos[k], infos[f"_{k}"] = info_array, array_mask
return infos
def _init_info_arrays(self, dtype: type) -> np.ndarray:
"""Initialize the info array.
Initialize the info array. If the dtype is numeric
the info array will have the same dtype, otherwise
will be an array of `None`. Also, a boolean array
of the same length is returned. It will be used for
assessing which environment has info data.
Args:
dtype (type): data type of the info coming from the env.
Returns:
array (np.ndarray): the initialized info array.
array_mask (np.ndarray): the initialized boolean array.
"""
if dtype in [int, float, bool] or issubclass(dtype, np.number):
array = np.zeros(self.num_envs, dtype=dtype)
else:
array = np.zeros(self.num_envs, dtype=object)
array[:] = None
array_mask = np.zeros(self.num_envs, dtype=bool)
return array, array_mask
def __del__(self):
"""Closes the vector environment."""
if not getattr(self, "closed", True):

View File

@@ -1,4 +1,5 @@
"""Module of wrapper classes."""
from gym import error
from gym.wrappers.atari_preprocessing import AtariPreprocessing
from gym.wrappers.autoreset import AutoResetWrapper
from gym.wrappers.clip_action import ClipAction
@@ -8,7 +9,6 @@ from gym.wrappers.frame_stack import FrameStack, LazyFrames
from gym.wrappers.gray_scale_observation import GrayScaleObservation
from gym.wrappers.normalize import NormalizeObservation, NormalizeReward
from gym.wrappers.order_enforcing import OrderEnforcing
from gym.wrappers.pixel_observation import PixelObservationWrapper
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
from gym.wrappers.rescale_action import RescaleAction
@@ -17,3 +17,4 @@ from gym.wrappers.time_aware_observation import TimeAwareObservation
from gym.wrappers.time_limit import TimeLimit
from gym.wrappers.transform_observation import TransformObservation
from gym.wrappers.transform_reward import TransformReward
from gym.wrappers.vector_list_info import VectorListInfo

View File

@@ -7,11 +7,44 @@ import numpy as np
import gym
def add_vector_episode_statistics(
info: dict, episode_info: dict, num_envs: int, env_num: int
):
"""Add episode statistics.
Add statistics coming from the vectorized environment.
Args:
info (dict): info dict of the environment.
episode_info (dict): episode statistics data.
num_envs (int): number of environments.
env_num (int): env number of the vectorized environments.
Returns:
info (dict): the input info dict with the episode statistics.
"""
info["episode"] = info.get("episode", {})
info["_episode"] = info.get("_episode", np.zeros(num_envs, dtype=bool))
info["_episode"][env_num] = True
for k in episode_info.keys():
info_array = info["episode"].get(k, np.zeros(num_envs))
info_array[env_num] = episode_info[k]
info["episode"][k] = info_array
return info
class RecordEpisodeStatistics(gym.Wrapper):
"""This wrapper will keep track of cumulative rewards and episode lengths.
At the end of an episode, the statistics of the episode will be added to ``info``. After the completion
of an episode, ``info`` will look like this::
At the end of an episode, the statistics of the episode will be added to ``info``
using the key ``episode``. If using a vectorized environment also the key
``_episode`` is used which indicates whether the env at the respective index has
the episode statistics.
After the completion of an episode, ``info`` will look like this::
>>> info = {
... ...
@@ -22,6 +55,18 @@ class RecordEpisodeStatistics(gym.Wrapper):
... },
... }
For a vectorized environments the output will be in the form of::
>>> infos = {
... ...
... "episode": {
... "r": "<array of cumulative reward>",
... "l": "<array of episode length>",
... "t": "<array of elapsed time since instantiation of wrapper>"
... },
... "_episode": "<boolean array of length num-envs>"
... }
Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
:attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.
@@ -57,34 +102,40 @@ class RecordEpisodeStatistics(gym.Wrapper):
def step(self, action):
"""Steps through the environment, recording the episode statistics."""
observations, rewards, dones, infos = super().step(action)
assert isinstance(
infos, dict
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
self.episode_returns += rewards
self.episode_lengths += 1
if not self.is_vector_env:
infos = [infos]
dones = [dones]
else:
infos = list(infos) # Convert infos to mutable type
dones = list(dones)
for i in range(len(dones)):
if dones[i]:
infos[i] = infos[i].copy()
episode_return = self.episode_returns[i]
episode_length = self.episode_lengths[i]
episode_info = {
"r": episode_return,
"l": episode_length,
"t": round(time.perf_counter() - self.t0, 6),
"episode": {
"r": episode_return,
"l": episode_length,
"t": round(time.perf_counter() - self.t0, 6),
}
}
infos[i]["episode"] = episode_info
if self.is_vector_env:
infos = add_vector_episode_statistics(
infos, episode_info["episode"], self.num_envs, i
)
else:
infos = {**infos, **episode_info}
self.return_queue.append(episode_return)
self.length_queue.append(episode_length)
self.episode_count += 1
self.episode_returns[i] = 0
self.episode_lengths[i] = 0
if self.is_vector_env:
infos = tuple(infos)
return (
observations,
rewards,
dones if self.is_vector_env else dones[0],
infos if self.is_vector_env else infos[0],
infos,
)

View File

@@ -0,0 +1,114 @@
"""Wrapper that converts the info format for vec envs into the list format."""
from typing import List
import gym
class VectorListInfo(gym.Wrapper):
"""Converts infos of vectorized envinroments from dict to List[dict].
This wrapper converts the info format of a
vector environment from a dictionary to a list of dictionaries.
This wrapper is intended to be used around vectorized
environments. If using other wrappers that perform
operation on info like `RecordEpisodeStatistics` this
need to be the outermost wrapper.
i.e. VectorListInfo(RecordEpisodeStatistics(envs))
Example::
>>> # actual
>>> {
... "k": np.array[0., 0., 0.5, 0.3],
... "_k": np.array[False, False, True, True]
... }
>>> # classic
>>> [{}, {}, {k: 0.5}, {k: 0.3}]
"""
def __init__(self, env):
"""This wrapper will convert the info into the list format.
Args:
env (Env): The environment to apply the wrapper
"""
assert getattr(
env, "is_vector_env", False
), "This wrapper can only be used in vectorized environments."
super().__init__(env)
def step(self, action):
"""Steps through the environment, convert dict info to list."""
observation, reward, done, infos = self.env.step(action)
list_info = self._convert_info_to_list(infos)
return observation, reward, done, list_info
def reset(self, **kwargs):
"""Resets the environment using kwargs."""
if not kwargs.get("return_info"):
return self.env.reset(**kwargs)
obs, infos = self.env.reset(**kwargs)
list_info = self._convert_info_to_list(infos)
return obs, list_info
def _convert_info_to_list(self, infos: dict) -> List[dict]:
"""Convert the dict info to list.
Convert the dict info of the vectorized environment
into a list of dictionaries where the i-th dictionary
has the info of the i-th environment.
Args:
infos (dict): info dict coming from the env.
Returns:
list_info (list): converted info.
"""
list_info = [{} for _ in range(self.num_envs)]
list_info = self._process_episode_statistics(infos, list_info)
for k in infos:
if k.startswith("_"):
continue
for i, has_info in enumerate(infos[f"_{k}"]):
if has_info:
list_info[i][k] = infos[k][i]
return list_info
def _process_episode_statistics(self, infos: dict, list_info: list) -> List[dict]:
"""Process episode statistics.
`RecordEpisodeStatistics` wrapper add extra
information to the info. This information are in
the form of a dict of dict. This method process these
information and add them to the info.
`RecordEpisodeStatistics` info contains the keys
"r", "l", "t" which represents "cumulative reward",
"episode length", "elapsed time since instantiation of wrapper".
Args:
infos (dict): infos coming from `RecordEpisodeStatistics`.
list_info (list): info of the current vectorized environment.
Returns:
list_info (list): updated info.
"""
episode_statistics = infos.pop("episode", False)
if not episode_statistics:
return list_info
episode_statistics_mask = infos.pop("_episode")
for i, has_info in enumerate(episode_statistics_mask):
if has_info:
list_info[i]["episode"] = {}
list_info[i]["episode"]["r"] = episode_statistics["r"][i]
list_info[i]["episode"]["l"] = episode_statistics["l"][i]
list_info[i]["episode"]["t"] = episode_statistics["t"][i]
return list_info

View File

@@ -63,7 +63,7 @@ def test_reset_async_vector_env(shared_memory):
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
assert isinstance(infos, list)
assert isinstance(infos, dict)
assert all([isinstance(info, dict) for info in infos])

View File

@@ -65,7 +65,7 @@ def test_reset_sync_vector_env():
assert observations.dtype == env.observation_space.dtype
assert observations.shape == (8,) + env.single_observation_space.shape
assert observations.shape == env.observation_space.shape
assert isinstance(infos, list)
assert isinstance(infos, dict)
assert all([isinstance(info, dict) for info in infos])

View File

@@ -35,11 +35,11 @@ def test_vector_env_equal(shared_memory):
sync_observations, sync_rewards, sync_dones, sync_infos = sync_env.step(actions)
# fmt: on
for idx in range(len(sync_dones)):
if sync_dones[idx]:
assert "terminal_observation" in async_infos[idx]
assert "terminal_observation" in sync_infos[idx]
assert sync_dones[idx]
if any(sync_dones):
assert "terminal_observation" in async_infos
assert "_terminal_observation" in async_infos
assert "terminal_observation" in sync_infos
assert "_terminal_observation" in sync_infos
assert np.all(async_observations == sync_observations)
assert np.all(async_rewards == sync_rewards)

View File

@@ -0,0 +1,54 @@
import numpy as np
import pytest
import gym
from gym.vector.sync_vector_env import SyncVectorEnv
from tests.vector.utils import make_env
ENV_ID = "CartPole-v1"
NUM_ENVS = 3
ENV_STEPS = 50
SEED = 42
@pytest.mark.parametrize("asynchronous", [True, False])
def test_vector_env_info(asynchronous):
env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, asynchronous=asynchronous)
env.reset(seed=SEED)
for _ in range(ENV_STEPS):
env.action_space.seed(SEED)
action = env.action_space.sample()
_, _, dones, infos = env.step(action)
if any(dones):
assert len(infos["terminal_observation"]) == NUM_ENVS
assert len(infos["_terminal_observation"]) == NUM_ENVS
assert isinstance(infos["terminal_observation"], np.ndarray)
assert isinstance(infos["_terminal_observation"], np.ndarray)
for i, done in enumerate(dones):
if done:
assert infos["_terminal_observation"][i]
else:
assert not infos["_terminal_observation"][i]
assert infos["terminal_observation"][i] is None
@pytest.mark.parametrize("concurrent_ends", [1, 2, 3])
def test_vector_env_info_concurrent_termination(concurrent_ends):
# envs that need to terminate together will have the same action
actions = [0] * concurrent_ends + [1] * (NUM_ENVS - concurrent_ends)
envs = [make_env(ENV_ID, SEED) for _ in range(NUM_ENVS)]
envs = SyncVectorEnv(envs)
for _ in range(ENV_STEPS):
_, _, dones, infos = envs.step(actions)
if any(dones):
for i, done in enumerate(dones):
if i < concurrent_ends:
assert done
assert infos["_terminal_observation"][i]
else:
assert not infos["_terminal_observation"][i]
assert infos["terminal_observation"][i] is None
return

View File

@@ -1,7 +1,9 @@
import numpy as np
import pytest
import gym
from gym.wrappers import RecordEpisodeStatistics
from gym.wrappers import RecordEpisodeStatistics, VectorListInfo
from gym.wrappers.record_episode_statistics import add_vector_episode_statistics
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
@@ -50,8 +52,47 @@ def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous):
envs.reset()
for _ in range(max_episode_step + 1):
_, _, dones, infos = envs.step(envs.action_space.sample())
for idx, info in enumerate(infos):
if dones[idx]:
assert "episode" in info
assert all([item in info["episode"] for item in ["r", "l", "t"]])
break
if any(dones):
assert "episode" in infos
assert "_episode" in infos
assert all(infos["_episode"] == dones)
assert all([item in infos["episode"] for item in ["r", "l", "t"]])
break
else:
assert "episode" not in infos
assert "_episode" not in infos
def test_wrong_wrapping_order():
envs = gym.vector.make("CartPole-v1", num_envs=3)
wrapped_env = RecordEpisodeStatistics(VectorListInfo(envs))
wrapped_env.reset()
with pytest.raises(AssertionError):
wrapped_env.step(wrapped_env.action_space.sample())
def test_add_vector_episode_statistics():
NUM_ENVS = 5
info = {}
for i in range(NUM_ENVS):
episode_info = {
"episode": {
"r": i,
"l": i,
"t": i,
}
}
info = add_vector_episode_statistics(info, episode_info["episode"], NUM_ENVS, i)
assert np.alltrue(info["_episode"][: i + 1])
for j in range(NUM_ENVS):
if j <= i:
assert info["episode"]["r"][j] == j
assert info["episode"]["l"][j] == j
assert info["episode"]["t"][j] == j
else:
assert info["episode"]["r"][j] == 0
assert info["episode"]["l"][j] == 0
assert info["episode"]["t"][j] == 0

View File

@@ -90,10 +90,11 @@ def test_record_video_within_vector():
envs.reset()
for i in range(199):
_, _, _, infos = envs.step(envs.action_space.sample())
for info in infos:
if "episode" in info.keys():
print(f"episode_reward={info['episode']['r']}")
break
# break when every env is done
if "episode" in infos and all(infos["_episode"]):
print(f"episode_reward={infos['episode']['r']}")
assert os.path.isdir("videos")
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
assert len(mp4_files) == 2

View File

@@ -0,0 +1,58 @@
import pytest
import gym
from gym.wrappers import RecordEpisodeStatistics, VectorListInfo
ENV_ID = "CartPole-v1"
NUM_ENVS = 3
ENV_STEPS = 50
SEED = 42
def test_usage_in_vector_env():
env = gym.make(ENV_ID)
vector_env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
VectorListInfo(vector_env)
with pytest.raises(AssertionError):
VectorListInfo(env)
def test_info_to_list():
env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
wrapped_env = VectorListInfo(env_to_wrap)
wrapped_env.action_space.seed(SEED)
_, info = wrapped_env.reset(seed=SEED, return_info=True)
assert isinstance(info, list)
assert len(info) == NUM_ENVS
for _ in range(ENV_STEPS):
action = wrapped_env.action_space.sample()
_, _, dones, list_info = wrapped_env.step(action)
for i, done in enumerate(dones):
if done:
assert "terminal_observation" in list_info[i]
else:
assert "terminal_observation" not in list_info[i]
def test_info_to_list_statistics():
env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
wrapped_env = VectorListInfo(RecordEpisodeStatistics(env_to_wrap))
_, info = wrapped_env.reset(seed=SEED, return_info=True)
wrapped_env.action_space.seed(SEED)
assert isinstance(info, list)
assert len(info) == NUM_ENVS
for _ in range(ENV_STEPS):
action = wrapped_env.action_space.sample()
_, _, dones, list_info = wrapped_env.step(action)
for i, done in enumerate(dones):
if done:
assert "episode" in list_info[i]
for stats in ["r", "l", "t"]:
assert stats in list_info[i]["episode"]
assert isinstance(list_info[i]["episode"][stats], float)
else:
assert "episode" not in list_info[i]