Mirror of https://github.com/Farama-Foundation/Gymnasium.git (synced 2025-08-19 13:32:03 +00:00)
Squashed commit messages:

* WIP refactor info API sync vector.
* Add missing untracked file.
* Add info strategy to reset_wait.
* Add interface and docstring.
* Info with strategy pattern on async vector env.
* Add default to async vec env.
* Episode statistics for AsyncVectorEnv.
* Add tests for info strategy format.
* Add info strategy to reset_wait.
* Refactor and cleanup.
* Code cleanup. Add tests.
* Add tests for video recording with new info format.
* Fix test case.
* Fix camel case.
* Rename enum.
* Update tests, docstrings, cleanup.
* Change Brax strategy to numpy. Add add_strategy method in StrategyFactory. Add tests.
* Fix docstring and logging format.
* Set Brax info format as default. Remove classic info format. Update tests.
* Fix: was breaking the wrong loop.
* WIP: wrapper.
* Add wrapper for Brax-to-classic info conversion.
* WIP: wrapper with nested RecordEpisodeStatistics.
* Add tests. Refactor docstrings. Cleanup.
* Cleanup.
* Patch conflicts.
* Rebase and resolve conflicts.
* New pre-commit conventions.
* Docstring.
* Renaming.
* Incorporate info_processor in VectorEnv.
* Renaming. Create info dict only if needed.
* Remove all Brax references. Update docstring. Update duplicate test.
* Reviews.
* Pre-commit.
* Reviews.
* Docstring.
* Clean up blank lines.
* Add support for numpy dtypes.
* Docstring fix.
* Formatting.
* Naming.
* Assert correct info from wrapper chaining. Test correct wrapper chaining. Naming.
* Simplify episode_statistics.
* Change args order.
* Update tests.
* WIP: refactor episode_statistics.
* Add test for add_vector_episode_statistics.
gym/vector/async_vector_env.py
@@ -271,8 +271,10 @@ class AsyncVectorEnv(VectorEnv):
         self._state = AsyncState.DEFAULT
 
         if return_info:
-            results, infos = zip(*results)
-            infos = list(infos)
+            infos = {}
+            results, info_data = zip(*results)
+            for i, info in enumerate(info_data):
+                infos = self._add_info(infos, info, i)
 
         if not self.shared_memory:
             self.observations = concatenate(
@@ -344,10 +346,20 @@ class AsyncVectorEnv(VectorEnv):
                 f"The call to `step_wait` has timed out after {timeout} second(s)."
             )
 
-        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
+        observations_list, rewards, dones, infos = [], [], [], {}
+        successes = []
+        for i, pipe in enumerate(self.parent_pipes):
+            result, success = pipe.recv()
+            obs, rew, done, info = result
+
+            successes.append(success)
+            observations_list.append(obs)
+            rewards.append(rew)
+            dones.append(done)
+            infos = self._add_info(infos, info, i)
+
         self._raise_if_errors(successes)
         self._state = AsyncState.DEFAULT
-        observations_list, rewards, dones, infos = zip(*results)
 
         if not self.shared_memory:
             self.observations = concatenate(
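The net effect of these two hunks: instead of a per-env tuple of info dicts, the vector env accumulates a single dict keyed by info name, with a parallel boolean mask under `_key`. A minimal sketch of the two layouts (the values here are invented for illustration):

import numpy as np

# Old "classic" format: one info dict per sub-environment.
classic_infos = ({}, {}, {"terminal_observation": np.zeros(4)})

# New format: one array per key, plus a boolean mask `_key` marking
# which sub-environments actually reported that key.
vector_infos = {
    "terminal_observation": np.empty(3, dtype=object),
    "_terminal_observation": np.array([False, False, True]),
}
vector_infos["terminal_observation"][2] = np.zeros(4)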
gym/vector/sync_vector_env.py
@@ -108,8 +108,8 @@ class SyncVectorEnv(VectorEnv):
 
         self._dones[:] = False
         observations = []
-        data_list = []
-        for env, single_seed in zip(self.envs, seed):
+        infos = {}
+        for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
 
             kwargs = {}
             if single_seed is not None:
@@ -123,9 +123,9 @@ class SyncVectorEnv(VectorEnv):
                 observation = env.reset(**kwargs)
                 observations.append(observation)
             else:
-                observation, data = env.reset(**kwargs)
+                observation, info = env.reset(**kwargs)
                 observations.append(observation)
-                data_list.append(data)
+                infos = self._add_info(infos, info, i)
 
         self.observations = concatenate(
             self.single_observation_space, observations, self.observations
@@ -135,7 +135,7 @@ class SyncVectorEnv(VectorEnv):
         else:
             return (
                 deepcopy(self.observations) if self.copy else self.observations
-            ), data_list
+            ), infos
 
     def step_async(self, actions):
         """Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
@@ -147,14 +147,14 @@ class SyncVectorEnv(VectorEnv):
         Returns:
             The batched environment step results
         """
-        observations, infos = [], []
+        observations, infos = [], {}
         for i, (env, action) in enumerate(zip(self.envs, self._actions)):
             observation, self._rewards[i], self._dones[i], info = env.step(action)
             if self._dones[i]:
                 info["terminal_observation"] = observation
                 observation = env.reset()
             observations.append(observation)
-            infos.append(info)
+            infos = self._add_info(infos, info, i)
         self.observations = concatenate(
             self.single_observation_space, observations, self.observations
         )
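A hedged sketch of what a caller sees after this change, assuming the gym API of this commit's era (4-tuple `step`, autoreset on done) and `gym.vector.make` as the new tests in this commit use it:

import gym

envs = gym.vector.make("CartPole-v1", num_envs=3)
envs.reset(seed=42)

for _ in range(200):
    _, _, dones, infos = envs.step(envs.action_space.sample())
    if any(dones):
        # The mask says which sub-envs just finished; only those slots
        # hold a terminal observation, the rest stay None.
        for i, finished in enumerate(infos["_terminal_observation"]):
            if finished:
                print(i, infos["terminal_observation"][i].shape)
        break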
gym/vector/vector_env.py
@@ -3,6 +3,8 @@ from __future__ import annotations
 
 from typing import Any, Optional, Union
 
+import numpy as np
+
 import gym
 from gym.logger import deprecation
 from gym.vector.utils.spaces import batch_space
@@ -201,6 +203,58 @@ class VectorEnv(gym.Env):
             "Please use `env.reset(seed=seed) instead in VectorEnvs."
         )
 
+    def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
+        """Add env info to the info dictionary of the vectorized environment.
+
+        Given the `info` of a single environment, add it to the `infos` dictionary
+        which represents all the infos of the vectorized environment.
+        Every `key` of `info` is paired with a boolean mask `_key` representing
+        whether or not the i-indexed environment has this `info`.
+
+        Args:
+            infos (dict): the infos of the vectorized environment
+            info (dict): the info coming from the single environment
+            env_num (int): the index of the single environment
+
+        Returns:
+            infos (dict): the (updated) infos of the vectorized environment
+        """
+        for k in info.keys():
+            if k not in infos:
+                info_array, array_mask = self._init_info_arrays(type(info[k]))
+            else:
+                info_array, array_mask = infos[k], infos[f"_{k}"]
+
+            info_array[env_num], array_mask[env_num] = info[k], True
+            infos[k], infos[f"_{k}"] = info_array, array_mask
+        return infos
+
+    def _init_info_arrays(self, dtype: type) -> np.ndarray:
+        """Initialize the info array.
+
+        If the dtype is numeric, the info array will have the same dtype;
+        otherwise it will be an array of `None`. A boolean array of the
+        same length is also returned; it is used for assessing which
+        environment has info data.
+
+        Args:
+            dtype (type): data type of the info coming from the env.
+
+        Returns:
+            array (np.ndarray): the initialized info array.
+            array_mask (np.ndarray): the initialized boolean array.
+        """
+        if dtype in [int, float, bool] or issubclass(dtype, np.number):
+            array = np.zeros(self.num_envs, dtype=dtype)
+        else:
+            array = np.zeros(self.num_envs, dtype=object)
+            array[:] = None
+        array_mask = np.zeros(self.num_envs, dtype=bool)
+        return array, array_mask
+
     def __del__(self):
         """Closes the vector environment."""
         if not getattr(self, "closed", True):
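To make the mask bookkeeping concrete, here is a small standalone re-implementation of the same logic (illustrative only; `num_envs` and the sample infos are invented for the sketch, this is not the class itself):

import numpy as np

num_envs = 3

def init_info_arrays(dtype):
    # Numeric dtypes keep their type; everything else becomes object/None.
    if dtype in [int, float, bool] or issubclass(dtype, np.number):
        array = np.zeros(num_envs, dtype=dtype)
    else:
        array = np.zeros(num_envs, dtype=object)
        array[:] = None
    return array, np.zeros(num_envs, dtype=bool)

infos = {}
for env_num, info in enumerate([{"reward": 1.0}, {}, {"reward": 0.5}]):
    for k in info:
        if k not in infos:
            infos[k], infos[f"_{k}"] = init_info_arrays(type(info[k]))
        infos[k][env_num], infos[f"_{k}"][env_num] = info[k], True

print(infos)
# {'reward': array([1. , 0. , 0.5]), '_reward': array([ True, False,  True])}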
gym/wrappers/__init__.py
@@ -1,4 +1,5 @@
 """Module of wrapper classes."""
+from gym import error
 from gym.wrappers.atari_preprocessing import AtariPreprocessing
 from gym.wrappers.autoreset import AutoResetWrapper
 from gym.wrappers.clip_action import ClipAction
@@ -8,7 +9,6 @@ from gym.wrappers.frame_stack import FrameStack, LazyFrames
 from gym.wrappers.gray_scale_observation import GrayScaleObservation
 from gym.wrappers.normalize import NormalizeObservation, NormalizeReward
 from gym.wrappers.order_enforcing import OrderEnforcing
-from gym.wrappers.pixel_observation import PixelObservationWrapper
 from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
 from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
 from gym.wrappers.rescale_action import RescaleAction
@@ -17,3 +17,4 @@ from gym.wrappers.time_aware_observation import TimeAwareObservation
 from gym.wrappers.time_limit import TimeLimit
 from gym.wrappers.transform_observation import TransformObservation
 from gym.wrappers.transform_reward import TransformReward
+from gym.wrappers.vector_list_info import VectorListInfo
gym/wrappers/record_episode_statistics.py
@@ -7,11 +7,44 @@ import numpy as np
 import gym
 
 
+def add_vector_episode_statistics(
+    info: dict, episode_info: dict, num_envs: int, env_num: int
+):
+    """Add episode statistics.
+
+    Add statistics coming from the vectorized environment.
+
+    Args:
+        info (dict): info dict of the environment.
+        episode_info (dict): episode statistics data.
+        num_envs (int): number of environments.
+        env_num (int): env number of the vectorized environments.
+
+    Returns:
+        info (dict): the input info dict with the episode statistics.
+    """
+    info["episode"] = info.get("episode", {})
+
+    info["_episode"] = info.get("_episode", np.zeros(num_envs, dtype=bool))
+    info["_episode"][env_num] = True
+
+    for k in episode_info.keys():
+        info_array = info["episode"].get(k, np.zeros(num_envs))
+        info_array[env_num] = episode_info[k]
+        info["episode"][k] = info_array
+
+    return info
+
+
 class RecordEpisodeStatistics(gym.Wrapper):
     """This wrapper will keep track of cumulative rewards and episode lengths.
 
-    At the end of an episode, the statistics of the episode will be added to ``info``. After the completion
-    of an episode, ``info`` will look like this::
+    At the end of an episode, the statistics of the episode will be added to ``info``
+    using the key ``episode``. If using a vectorized environment, the key
+    ``_episode`` is also used, which indicates whether the env at the respective index
+    has the episode statistics.
+
+    After the completion of an episode, ``info`` will look like this::
+
     >>> info = {
     ...     ...
@@ -22,6 +55,18 @@ class RecordEpisodeStatistics(gym.Wrapper):
     ...     },
     ... }
 
+    For vectorized environments the output will be in the form of::
+
+    >>> infos = {
+    ...     ...
+    ...     "episode": {
+    ...         "r": "<array of cumulative reward>",
+    ...         "l": "<array of episode length>",
+    ...         "t": "<array of elapsed time since instantiation of wrapper>"
+    ...     },
+    ...     "_episode": "<boolean array of length num-envs>"
+    ... }
+
     Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
     :attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively.
 
@@ -57,34 +102,40 @@ class RecordEpisodeStatistics(gym.Wrapper):
     def step(self, action):
         """Steps through the environment, recording the episode statistics."""
         observations, rewards, dones, infos = super().step(action)
+        assert isinstance(
+            infos, dict
+        ), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
         self.episode_returns += rewards
         self.episode_lengths += 1
         if not self.is_vector_env:
-            infos = [infos]
             dones = [dones]
-        else:
-            infos = list(infos)  # Convert infos to mutable type
+        dones = list(dones)
         for i in range(len(dones)):
             if dones[i]:
-                infos[i] = infos[i].copy()
                 episode_return = self.episode_returns[i]
                 episode_length = self.episode_lengths[i]
                 episode_info = {
-                    "r": episode_return,
-                    "l": episode_length,
-                    "t": round(time.perf_counter() - self.t0, 6),
+                    "episode": {
+                        "r": episode_return,
+                        "l": episode_length,
+                        "t": round(time.perf_counter() - self.t0, 6),
+                    }
                 }
-                infos[i]["episode"] = episode_info
+                if self.is_vector_env:
+                    infos = add_vector_episode_statistics(
+                        infos, episode_info["episode"], self.num_envs, i
+                    )
+                else:
+                    infos = {**infos, **episode_info}
                 self.return_queue.append(episode_return)
                 self.length_queue.append(episode_length)
                 self.episode_count += 1
                 self.episode_returns[i] = 0
                 self.episode_lengths[i] = 0
-        if self.is_vector_env:
-            infos = tuple(infos)
         return (
             observations,
             rewards,
             dones if self.is_vector_env else dones[0],
-            infos if self.is_vector_env else infos[0],
+            infos,
         )
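A hedged sketch of reading vectorized episode statistics in the new format, mirroring what the updated tests in this commit do (CartPole episodes end within a few hundred random steps, so the break is reached):

import gym
from gym.wrappers import RecordEpisodeStatistics

envs = RecordEpisodeStatistics(gym.vector.make("CartPole-v1", num_envs=3))
envs.reset()

for _ in range(500):
    _, _, dones, infos = envs.step(envs.action_space.sample())
    if "episode" in infos:
        # `_episode` masks which sub-envs just completed an episode.
        for i, finished in enumerate(infos["_episode"]):
            if finished:
                print(f"env {i}: return={infos['episode']['r'][i]}, "
                      f"length={infos['episode']['l'][i]}")
        break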
gym/wrappers/vector_list_info.py (new file, 114 lines)
@@ -0,0 +1,114 @@
"""Wrapper that converts the info format for vec envs into the list format."""

from typing import List

import gym


class VectorListInfo(gym.Wrapper):
    """Converts infos of vectorized environments from dict to List[dict].

    This wrapper converts the info format of a
    vector environment from a dictionary to a list of dictionaries.
    This wrapper is intended to be used around vectorized
    environments. If using other wrappers that perform
    operations on info, like `RecordEpisodeStatistics`, this
    needs to be the outermost wrapper.

    i.e. VectorListInfo(RecordEpisodeStatistics(envs))

    Example::

        >>> # actual
        >>> {
        ...     "k": np.array[0., 0., 0.5, 0.3],
        ...     "_k": np.array[False, False, True, True]
        ... }
        >>> # classic
        >>> [{}, {}, {k: 0.5}, {k: 0.3}]

    """

    def __init__(self, env):
        """This wrapper will convert the info into the list format.

        Args:
            env (Env): The environment to apply the wrapper
        """
        assert getattr(
            env, "is_vector_env", False
        ), "This wrapper can only be used in vectorized environments."
        super().__init__(env)

    def step(self, action):
        """Steps through the environment, converting dict info to list."""
        observation, reward, done, infos = self.env.step(action)
        list_info = self._convert_info_to_list(infos)

        return observation, reward, done, list_info

    def reset(self, **kwargs):
        """Resets the environment using kwargs."""
        if not kwargs.get("return_info"):
            return self.env.reset(**kwargs)

        obs, infos = self.env.reset(**kwargs)
        list_info = self._convert_info_to_list(infos)
        return obs, list_info

    def _convert_info_to_list(self, infos: dict) -> List[dict]:
        """Convert the dict info to list.

        Convert the dict info of the vectorized environment
        into a list of dictionaries where the i-th dictionary
        has the info of the i-th environment.

        Args:
            infos (dict): info dict coming from the env.

        Returns:
            list_info (list): converted info.

        """
        list_info = [{} for _ in range(self.num_envs)]
        list_info = self._process_episode_statistics(infos, list_info)
        for k in infos:
            if k.startswith("_"):
                continue
            for i, has_info in enumerate(infos[f"_{k}"]):
                if has_info:
                    list_info[i][k] = infos[k][i]
        return list_info

    def _process_episode_statistics(self, infos: dict, list_info: list) -> List[dict]:
        """Process episode statistics.

        The `RecordEpisodeStatistics` wrapper adds extra information
        to the info in the form of a dict of dicts. This method
        processes that information and adds it to the list info.
        `RecordEpisodeStatistics` info contains the keys "r", "l", "t",
        which represent "cumulative reward", "episode length" and
        "elapsed time since instantiation of wrapper".

        Args:
            infos (dict): infos coming from `RecordEpisodeStatistics`.
            list_info (list): info of the current vectorized environment.

        Returns:
            list_info (list): updated info.

        """
        episode_statistics = infos.pop("episode", False)
        if not episode_statistics:
            return list_info

        episode_statistics_mask = infos.pop("_episode")
        for i, has_info in enumerate(episode_statistics_mask):
            if has_info:
                list_info[i]["episode"] = {}
                list_info[i]["episode"]["r"] = episode_statistics["r"][i]
                list_info[i]["episode"]["l"] = episode_statistics["l"][i]
                list_info[i]["episode"]["t"] = episode_statistics["t"][i]

        return list_info
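A hedged usage sketch of the new wrapper, following the wrapping order its docstring requires (VectorListInfo outermost); episode values printed are whatever the run produces:

import gym
from gym.wrappers import RecordEpisodeStatistics, VectorListInfo

envs = VectorListInfo(
    RecordEpisodeStatistics(gym.vector.make("CartPole-v1", num_envs=3))
)
envs.reset()

for _ in range(500):
    _, _, dones, list_info = envs.step(envs.action_space.sample())
    if any(dones):
        for i, done in enumerate(dones):
            if done:
                # Back to one plain dict per sub-environment.
                print(i, list_info[i]["episode"])
        break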
tests/vector/test_async_vector_env.py
@@ -63,7 +63,7 @@ def test_reset_async_vector_env(shared_memory):
     assert observations.dtype == env.observation_space.dtype
     assert observations.shape == (8,) + env.single_observation_space.shape
     assert observations.shape == env.observation_space.shape
-    assert isinstance(infos, list)
+    assert isinstance(infos, dict)
     assert all([isinstance(info, dict) for info in infos])
 
 
tests/vector/test_sync_vector_env.py
@@ -65,7 +65,7 @@ def test_reset_sync_vector_env():
     assert observations.dtype == env.observation_space.dtype
     assert observations.shape == (8,) + env.single_observation_space.shape
     assert observations.shape == env.observation_space.shape
-    assert isinstance(infos, list)
+    assert isinstance(infos, dict)
     assert all([isinstance(info, dict) for info in infos])
 
 
tests/vector/test_vector_env.py
@@ -35,11 +35,11 @@ def test_vector_env_equal(shared_memory):
         sync_observations, sync_rewards, sync_dones, sync_infos = sync_env.step(actions)
         # fmt: on
 
-        for idx in range(len(sync_dones)):
-            if sync_dones[idx]:
-                assert "terminal_observation" in async_infos[idx]
-                assert "terminal_observation" in sync_infos[idx]
-                assert sync_dones[idx]
+        if any(sync_dones):
+            assert "terminal_observation" in async_infos
+            assert "_terminal_observation" in async_infos
+            assert "terminal_observation" in sync_infos
+            assert "_terminal_observation" in sync_infos
 
         assert np.all(async_observations == sync_observations)
         assert np.all(async_rewards == sync_rewards)
tests/vector/test_vector_env_info.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import numpy as np
import pytest

import gym
from gym.vector.sync_vector_env import SyncVectorEnv
from tests.vector.utils import make_env

ENV_ID = "CartPole-v1"
NUM_ENVS = 3
ENV_STEPS = 50
SEED = 42


@pytest.mark.parametrize("asynchronous", [True, False])
def test_vector_env_info(asynchronous):
    env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, asynchronous=asynchronous)
    env.reset(seed=SEED)
    for _ in range(ENV_STEPS):
        env.action_space.seed(SEED)
        action = env.action_space.sample()
        _, _, dones, infos = env.step(action)
        if any(dones):
            assert len(infos["terminal_observation"]) == NUM_ENVS
            assert len(infos["_terminal_observation"]) == NUM_ENVS

            assert isinstance(infos["terminal_observation"], np.ndarray)
            assert isinstance(infos["_terminal_observation"], np.ndarray)

            for i, done in enumerate(dones):
                if done:
                    assert infos["_terminal_observation"][i]
                else:
                    assert not infos["_terminal_observation"][i]
                    assert infos["terminal_observation"][i] is None


@pytest.mark.parametrize("concurrent_ends", [1, 2, 3])
def test_vector_env_info_concurrent_termination(concurrent_ends):
    # envs that need to terminate together will have the same action
    actions = [0] * concurrent_ends + [1] * (NUM_ENVS - concurrent_ends)
    envs = [make_env(ENV_ID, SEED) for _ in range(NUM_ENVS)]
    envs = SyncVectorEnv(envs)

    for _ in range(ENV_STEPS):
        _, _, dones, infos = envs.step(actions)
        if any(dones):
            for i, done in enumerate(dones):
                if i < concurrent_ends:
                    assert done
                    assert infos["_terminal_observation"][i]
                else:
                    assert not infos["_terminal_observation"][i]
                    assert infos["terminal_observation"][i] is None
            return
tests/wrappers/test_record_episode_statistics.py
@@ -1,7 +1,9 @@
+import numpy as np
 import pytest
 
 import gym
-from gym.wrappers import RecordEpisodeStatistics
+from gym.wrappers import RecordEpisodeStatistics, VectorListInfo
+from gym.wrappers.record_episode_statistics import add_vector_episode_statistics
 
 
 @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
@@ -50,8 +52,47 @@ def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous):
     envs.reset()
     for _ in range(max_episode_step + 1):
         _, _, dones, infos = envs.step(envs.action_space.sample())
-        for idx, info in enumerate(infos):
-            if dones[idx]:
-                assert "episode" in info
-                assert all([item in info["episode"] for item in ["r", "l", "t"]])
-                break
+        if any(dones):
+            assert "episode" in infos
+            assert "_episode" in infos
+            assert all(infos["_episode"] == dones)
+            assert all([item in infos["episode"] for item in ["r", "l", "t"]])
+            break
+        else:
+            assert "episode" not in infos
+            assert "_episode" not in infos
+
+
+def test_wrong_wrapping_order():
+    envs = gym.vector.make("CartPole-v1", num_envs=3)
+    wrapped_env = RecordEpisodeStatistics(VectorListInfo(envs))
+    wrapped_env.reset()
+
+    with pytest.raises(AssertionError):
+        wrapped_env.step(wrapped_env.action_space.sample())
+
+
+def test_add_vector_episode_statistics():
+    NUM_ENVS = 5
+
+    info = {}
+    for i in range(NUM_ENVS):
+        episode_info = {
+            "episode": {
+                "r": i,
+                "l": i,
+                "t": i,
+            }
+        }
+        info = add_vector_episode_statistics(info, episode_info["episode"], NUM_ENVS, i)
+        assert np.alltrue(info["_episode"][: i + 1])
+
+        for j in range(NUM_ENVS):
+            if j <= i:
+                assert info["episode"]["r"][j] == j
+                assert info["episode"]["l"][j] == j
+                assert info["episode"]["t"][j] == j
+            else:
+                assert info["episode"]["r"][j] == 0
+                assert info["episode"]["l"][j] == 0
+                assert info["episode"]["t"][j] == 0
tests/wrappers/test_record_video.py
@@ -90,10 +90,11 @@ def test_record_video_within_vector():
     envs.reset()
     for i in range(199):
         _, _, _, infos = envs.step(envs.action_space.sample())
-        for info in infos:
-            if "episode" in info.keys():
-                print(f"episode_reward={info['episode']['r']}")
-                break
+
+        # break when every env is done
+        if "episode" in infos and all(infos["_episode"]):
+            print(f"episode_reward={infos['episode']['r']}")
+            break
 
     assert os.path.isdir("videos")
     mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
     assert len(mp4_files) == 2
tests/wrappers/test_vector_list_info.py (new file, 58 lines)
@@ -0,0 +1,58 @@
import pytest

import gym
from gym.wrappers import RecordEpisodeStatistics, VectorListInfo

ENV_ID = "CartPole-v1"
NUM_ENVS = 3
ENV_STEPS = 50
SEED = 42


def test_usage_in_vector_env():
    env = gym.make(ENV_ID)
    vector_env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)

    VectorListInfo(vector_env)

    with pytest.raises(AssertionError):
        VectorListInfo(env)


def test_info_to_list():
    env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
    wrapped_env = VectorListInfo(env_to_wrap)
    wrapped_env.action_space.seed(SEED)
    _, info = wrapped_env.reset(seed=SEED, return_info=True)
    assert isinstance(info, list)
    assert len(info) == NUM_ENVS

    for _ in range(ENV_STEPS):
        action = wrapped_env.action_space.sample()
        _, _, dones, list_info = wrapped_env.step(action)
        for i, done in enumerate(dones):
            if done:
                assert "terminal_observation" in list_info[i]
            else:
                assert "terminal_observation" not in list_info[i]


def test_info_to_list_statistics():
    env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
    wrapped_env = VectorListInfo(RecordEpisodeStatistics(env_to_wrap))
    _, info = wrapped_env.reset(seed=SEED, return_info=True)
    wrapped_env.action_space.seed(SEED)
    assert isinstance(info, list)
    assert len(info) == NUM_ENVS

    for _ in range(ENV_STEPS):
        action = wrapped_env.action_space.sample()
        _, _, dones, list_info = wrapped_env.step(action)
        for i, done in enumerate(dones):
            if done:
                assert "episode" in list_info[i]
                for stats in ["r", "l", "t"]:
                    assert stats in list_info[i]["episode"]
                    assert isinstance(list_info[i]["episode"][stats], float)
            else:
                assert "episode" not in list_info[i]