Change autoreset order (#808)

Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
Ariel Kwiatkowski
2023-12-03 19:50:18 +01:00
committed by GitHub
parent 967bbf5823
commit e9c66e4225
18 changed files with 591 additions and 364 deletions

View File

@@ -151,6 +151,8 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32) self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)
self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_)
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box) self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
if self.render_mode == "rgb_array": if self.render_mode == "rgb_array":
@@ -214,10 +216,9 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
info = self.func_env.transition_info(self.state, action, next_state) info = self.func_env.transition_info(self.state, action, next_state)
done = jnp.logical_or(terminated, truncated) done = jnp.logical_or(terminated, truncated)
if jnp.any(done):
final_obs = self.func_env.observation(next_state)
to_reset = jnp.where(done)[0] if jnp.any(self.autoreset_envs):
to_reset = jnp.where(self.autoreset_envs)[0]
reset_count = to_reset.shape[0] reset_count = to_reset.shape[0]
rng, self.rng = jrng.split(self.rng) rng, self.rng = jrng.split(self.rng)
@@ -228,34 +229,16 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
next_state = self.state.at[to_reset].set(new_initials) next_state = self.state.at[to_reset].set(new_initials)
self.steps = self.steps.at[to_reset].set(0) self.steps = self.steps.at[to_reset].set(0)
# Get the final observations and infos self.autoreset_envs = done
info["final_observation"] = np.array([None for _ in range(self.num_envs)])
info["final_info"] = np.array([None for _ in range(self.num_envs)])
info["_final_observation"] = np.array([False for _ in range(self.num_envs)])
info["_final_info"] = np.array([False for _ in range(self.num_envs)])
# TODO: this can maybe be optimized, but right now I don't know how
for i in to_reset:
info["final_observation"][i] = final_obs[i]
info["final_info"][i] = {
k: v[i]
for k, v in info.items()
if k
not in {
"final_observation",
"final_info",
"_final_observation",
"_final_info",
}
}
info["_final_observation"][i] = True
info["_final_info"][i] = True
observation = self.func_env.observation(next_state) observation = self.func_env.observation(next_state)
observation = jax_to_numpy(observation) observation = jax_to_numpy(observation)
reward = jax_to_numpy(reward)
terminated = jax_to_numpy(terminated)
truncated = jax_to_numpy(truncated)
self.state = next_state self.state = next_state
return observation, reward, terminated, truncated, info return observation, reward, terminated, truncated, info

View File

@@ -639,6 +639,7 @@ def _async_worker(
env = env_fn() env = env_fn()
observation_space = env.observation_space observation_space = env.observation_space
action_space = env.action_space action_space = env.action_space
autoreset = False
parent_pipe.close() parent_pipe.close()
@@ -653,20 +654,21 @@ def _async_worker(
observation_space, index, observation, shared_memory observation_space, index, observation, shared_memory
) )
observation = None observation = None
autoreset = False
pipe.send(((observation, info), True)) pipe.send(((observation, info), True))
elif command == "step": elif command == "step":
( if autoreset:
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
if terminated or truncated:
old_observation, old_info = observation, info
observation, info = env.reset() observation, info = env.reset()
info["final_observation"] = old_observation reward, terminated, truncated = 0, False, False
info["final_info"] = old_info else:
(
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
autoreset = terminated or truncated
if shared_memory: if shared_memory:
write_to_shared_memory( write_to_shared_memory(

View File

@@ -98,6 +98,8 @@ class SyncVectorEnv(VectorEnv):
self._terminations = np.zeros((self.num_envs,), dtype=np.bool_) self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
self._truncations = np.zeros((self.num_envs,), dtype=np.bool_) self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)
self._autoreset_envs = np.zeros((self.num_envs,), dtype=np.bool_)
def reset( def reset(
self, self,
*, *,
@@ -150,22 +152,21 @@ class SyncVectorEnv(VectorEnv):
actions = iterate(self.action_space, actions) actions = iterate(self.action_space, actions)
observations, infos = [], {} observations, infos = [], {}
for i, (env, action) in enumerate(zip(self.envs, actions)): for i, action in enumerate(actions):
( if self._autoreset_envs[i]:
env_obs, env_obs, env_info = self.envs[i].reset()
self._rewards[i],
self._terminations[i],
self._truncations[i],
env_info,
) = env.step(action)
# If sub-environments terminates or truncates then save the obs and info to the batched info self._rewards[i] = 0.0
if self._terminations[i] or self._truncations[i]: self._terminations[i] = False
old_observation, old_info = env_obs, env_info self._truncations[i] = False
env_obs, env_info = env.reset() else:
(
env_info["final_observation"] = old_observation env_obs,
env_info["final_info"] = old_info self._rewards[i],
self._terminations[i],
self._truncations[i],
env_info,
) = self.envs[i].step(action)
observations.append(env_obs) observations.append(env_obs)
infos = self._add_info(infos, env_info, i) infos = self._add_info(infos, env_info, i)
@@ -174,6 +175,7 @@ class SyncVectorEnv(VectorEnv):
self._observations = concatenate( self._observations = concatenate(
self.single_observation_space, observations, self._observations self.single_observation_space, observations, self._observations
) )
self._autoreset_envs = np.logical_or(self._terminations, self._truncations)
return ( return (
deepcopy(self._observations) if self.copy else self._observations, deepcopy(self._observations) if self.copy else self._observations,

View File

@@ -231,7 +231,7 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
return self return self
def _add_info( def _add_info(
self, infos: dict[str, Any], info: dict[str, Any], env_num: int self, vector_infos: dict[str, Any], env_info: dict[str, Any], env_num: int
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Add env info to the info dictionary of the vectorized environment. """Add env info to the info dictionary of the vectorized environment.
@@ -241,48 +241,51 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
whether or not the i-indexed environment has this `info`. whether or not the i-indexed environment has this `info`.
Args: Args:
infos (dict): the infos of the vectorized environment vector_infos (dict): the infos of the vectorized environment
info (dict): the info coming from the single environment env_info (dict): the info coming from the single environment
env_num (int): the index of the single environment env_num (int): the index of the single environment
Returns: Returns:
infos (dict): the (updated) infos of the vectorized environment infos (dict): the (updated) infos of the vectorized environment
""" """
for k in info.keys(): for key, value in env_info.items():
if k not in infos: # If value is a dictionary, then we apply the `_add_info` recursively.
info_array, array_mask = self._init_info_arrays(type(info[k])) if isinstance(value, dict):
array = self._add_info(vector_infos.get(key, {}), value, env_num)
# Otherwise, we are a base case to group the data
else: else:
info_array, array_mask = infos[k], infos[f"_{k}"] # If the key doesn't exist in the vector infos, then we can create an array of that batch type
if key not in vector_infos:
if type(value) in [int, float, bool] or issubclass(
type(value), np.number
):
array = np.zeros(self.num_envs, dtype=type(value))
elif isinstance(value, np.ndarray):
# We assume that all instances of the np.array info are of the same shape
array = np.zeros(
(self.num_envs, *value.shape), dtype=value.dtype
)
else:
# For unknown objects, we use a Numpy object array
array = np.full(self.num_envs, fill_value=None, dtype=object)
# Otherwise, just use the array that already exists
else:
array = vector_infos[key]
info_array[env_num], array_mask[env_num] = info[k], True # Assign the data in the `env_num` position
infos[k], infos[f"_{k}"] = info_array, array_mask # We only want to run this for the base-case data (not recursive data forcing the ugly function structure)
return infos array[env_num] = value
def _init_info_arrays(self, dtype: type) -> tuple[np.ndarray, np.ndarray]: # Get the array mask and if it doesn't already exist then create a zero bool array
"""Initialize the info array. array_mask = vector_infos.get(
f"_{key}", np.zeros(self.num_envs, dtype=np.bool_)
)
array_mask[env_num] = True
Initialize the info array. If the dtype is numeric # Update the vector info with the updated data and mask information
the info array will have the same dtype, otherwise vector_infos[key], vector_infos[f"_{key}"] = array, array_mask
will be an array of `None`. Also, a boolean array
of the same length is returned. It will be used for
assessing which environment has info data.
Args: return vector_infos
dtype (type): data type of the info coming from the env.
Returns:
array (np.ndarray): the initialized info array.
array_mask (np.ndarray): the initialized boolean array.
"""
if dtype in [int, float, bool] or issubclass(dtype, np.number):
array = np.zeros(self.num_envs, dtype=dtype)
else:
array = np.zeros(self.num_envs, dtype=object)
array[:] = None
array_mask = np.zeros(self.num_envs, dtype=bool)
return array, array_mask
def __del__(self): def __del__(self):
"""Closes the vector environment.""" """Closes the vector environment."""
@@ -441,23 +444,23 @@ class VectorObservationWrapper(VectorWrapper):
options: dict[str, Any] | None = None, options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]: ) -> tuple[ObsType, dict[str, Any]]:
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`.""" """Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
obs, info = self.env.reset(seed=seed, options=options) observations, infos = self.env.reset(seed=seed, options=options)
return self.vector_observation(obs), info return self.observation(observations), infos
def step( def step(
self, actions: ActType self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]: ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`.""" """Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
observation, reward, termination, truncation, info = self.env.step(actions) observations, rewards, terminations, truncations, infos = self.env.step(actions)
return ( return (
self.vector_observation(observation), self.observation(observations),
reward, rewards,
termination, terminations,
truncation, truncations,
self.update_final_obs(info), infos,
) )
def vector_observation(self, observation: ObsType) -> ObsType: def observation(self, observation: ObsType) -> ObsType:
"""Defines the vector observation transformation. """Defines the vector observation transformation.
Args: Args:
@@ -468,25 +471,6 @@ class VectorObservationWrapper(VectorWrapper):
""" """
raise NotImplementedError raise NotImplementedError
def single_observation(self, observation: ObsType) -> ObsType:
"""Defines the single observation transformation.
Args:
observation: A single observation from the environment
Returns:
The transformed observation
"""
raise NotImplementedError
def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
"""Updates the `final_obs` in the info using `single_observation`."""
if "final_observation" in info:
for i, obs in enumerate(info["final_observation"]):
if obs is not None:
info["final_observation"][i] = self.single_observation(obs)
return info
class VectorActionWrapper(VectorWrapper): class VectorActionWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the actions. """Wraps the vectorized environment to allow a modular transformation of the actions.
@@ -522,14 +506,14 @@ class VectorRewardWrapper(VectorWrapper):
self, actions: ActType self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]: ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Steps through the environment returning a reward modified by :meth:`reward`.""" """Steps through the environment returning a reward modified by :meth:`reward`."""
observation, reward, termination, truncation, info = self.env.step(actions) observations, rewards, terminations, truncations, infos = self.env.step(actions)
return observation, self.rewards(reward), termination, truncation, info return observations, self.rewards(rewards), terminations, truncations, infos
def rewards(self, reward: ArrayType) -> ArrayType: def rewards(self, rewards: ArrayType) -> ArrayType:
"""Transform the reward before returning it. """Transform the reward before returning it.
Args: Args:
reward (array): the reward to transform rewards (array): the reward to transform
Returns: Returns:
array: the transformed reward array: the transformed reward

View File

@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any, SupportsFloat
import gymnasium as gym import gymnasium as gym
from gymnasium import logger from gymnasium import logger
from gymnasium.core import ActType, ObsType, RenderFrame from gymnasium.core import ActType, ObsType, RenderFrame, WrapperObsType
from gymnasium.error import ResetNeeded from gymnasium.error import ResetNeeded
from gymnasium.utils.passive_env_checker import ( from gymnasium.utils.passive_env_checker import (
check_action_space, check_action_space,
@@ -196,6 +196,15 @@ class Autoreset(
gym.utils.RecordConstructorArgs.__init__(self) gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env) gym.Wrapper.__init__(self, env)
self.autoreset = False
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment and sets autoreset to False preventing."""
self.autoreset = False
return super().reset(seed=seed, options=options)
def step( def step(
self, action: ActType self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
@@ -207,24 +216,13 @@ class Autoreset(
Returns: Returns:
The autoreset environment :meth:`step` The autoreset environment :meth:`step`
""" """
obs, reward, terminated, truncated, info = self.env.step(action) if self.autoreset:
obs, info = self.env.reset()
if terminated or truncated: reward, terminated, truncated = 0.0, False, False
new_obs, new_info = self.env.reset() else:
obs, reward, terminated, truncated, info = self.env.step(action)
assert (
"final_observation" not in new_info
), f'new info dict already contains "final_observation", info keys: {new_info.keys()}'
assert (
"final_info" not in new_info
), f'new info dict already contains "final_observation", info keys: {new_info.keys()}'
new_info["final_observation"] = obs
new_info["final_info"] = info
obs = new_obs
info = new_info
self.autoreset = terminated or truncated
return obs, reward, terminated, truncated, info return obs, reward, terminated, truncated, info
@@ -470,14 +468,14 @@ class RecordEpisodeStatistics(
def __init__( def __init__(
self, self,
env: gym.Env[ObsType, ActType], env: gym.Env[ObsType, ActType],
buffer_length: int | None = 100, buffer_length: int = 100,
stats_key: str = "episode", stats_key: str = "episode",
): ):
"""This wrapper will keep track of cumulative rewards and episode lengths. """This wrapper will keep track of cumulative rewards and episode lengths.
Args: Args:
env (Env): The environment to apply the wrapper env (Env): The environment to apply the wrapper
buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue` buffer_length: The size of the buffers :attr:`return_queue`, :attr:`length_queue` and :attr:`time_queue`
stats_key: The info key for the episode statistics stats_key: The info key for the episode statistics
""" """
gym.utils.RecordConstructorArgs.__init__(self) gym.utils.RecordConstructorArgs.__init__(self)
@@ -520,6 +518,7 @@ class RecordEpisodeStatistics(
self.length_queue.append(self.episode_lengths) self.length_queue.append(self.episode_lengths)
self.episode_count += 1 self.episode_count += 1
self.episode_start_time = time.perf_counter()
return obs, reward, terminated, truncated, info return obs, reward, terminated, truncated, info

View File

@@ -444,7 +444,7 @@ class NormalizeObservation(
Change logs: Change logs:
* v0.21.0 - Initially add * v0.21.0 - Initially add
* v1.0.0 - Add `update_running_mean` attribute to allow disabling of updating the running mean / standard * v1.0.0 - Add `update_running_mean` attribute to allow disabling of updating the running mean / standard, particularly useful for evaluation time.
""" """
def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8): def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):

View File

@@ -62,14 +62,21 @@ class RecordEpisodeStatistics(VectorWrapper):
None, None], dtype=object)} None, None], dtype=object)}
""" """
def __init__(self, env: VectorEnv, deque_size: int = 100): def __init__(
self,
env: VectorEnv,
deque_size: int = 100,
stats_key: str = "episode",
):
"""This wrapper will keep track of cumulative rewards and episode lengths. """This wrapper will keep track of cumulative rewards and episode lengths.
Args: Args:
env (Env): The environment to apply the wrapper env (Env): The environment to apply the wrapper
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue` deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
stats_key: The info key to save the data
""" """
super().__init__(env) super().__init__(env)
self._stats_key = stats_key
self.episode_count = 0 self.episode_count = 0
@@ -77,6 +84,7 @@ class RecordEpisodeStatistics(VectorWrapper):
self.episode_returns: np.ndarray = np.zeros(()) self.episode_returns: np.ndarray = np.zeros(())
self.episode_lengths: np.ndarray = np.zeros(()) self.episode_lengths: np.ndarray = np.zeros(())
self.time_queue = deque(maxlen=deque_size)
self.return_queue = deque(maxlen=deque_size) self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size) self.length_queue = deque(maxlen=deque_size)
@@ -88,11 +96,9 @@ class RecordEpisodeStatistics(VectorWrapper):
"""Resets the environment using kwargs and resets the episode returns and lengths.""" """Resets the environment using kwargs and resets the episode returns and lengths."""
obs, info = super().reset(seed=seed, options=options) obs, info = super().reset(seed=seed, options=options)
self.episode_start_times = np.full( self.episode_start_times = np.full(self.num_envs, time.perf_counter())
self.num_envs, time.perf_counter(), dtype=np.float32 self.episode_returns = np.zeros(self.num_envs)
) self.episode_lengths = np.zeros(self.num_envs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return obs, info return obs, info
@@ -110,7 +116,7 @@ class RecordEpisodeStatistics(VectorWrapper):
assert isinstance( assert isinstance(
infos, dict infos, dict
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order." ), f"`vector.RecordEpisodeStatistics` requires `info` type to be `dict`, its actual type is {type(infos)}. This may be due to usage of other wrappers in the wrong order."
self.episode_returns += rewards self.episode_returns += rewards
self.episode_lengths += 1 self.episode_lengths += 1
@@ -119,25 +125,25 @@ class RecordEpisodeStatistics(VectorWrapper):
num_dones = np.sum(dones) num_dones = np.sum(dones)
if num_dones: if num_dones:
if "episode" in infos or "_episode" in infos: if self._stats_key in infos or f"_{self._stats_key}" in infos:
raise ValueError( raise ValueError(
"Attempted to add episode stats when they already exist" f"Attempted to add episode stats when they already exist, info keys: {list(infos.keys())}"
) )
else: else:
infos["episode"] = { episode_time_length = np.round(
time.perf_counter() - self.episode_start_times, 6
)
infos[self._stats_key] = {
"r": np.where(dones, self.episode_returns, 0.0), "r": np.where(dones, self.episode_returns, 0.0),
"l": np.where(dones, self.episode_lengths, 0), "l": np.where(dones, self.episode_lengths, 0),
"t": np.where( "t": np.where(dones, episode_time_length, 0.0),
dones,
np.round(time.perf_counter() - self.episode_start_times, 6),
0.0,
),
} }
infos["_episode"] = dones infos[f"_{self._stats_key}"] = dones
self.episode_count += num_dones self.episode_count += num_dones
for i in np.where(dones): for i in np.where(dones):
self.time_queue.extend(episode_time_length[i])
self.return_queue.extend(self.episode_returns[i]) self.return_queue.extend(self.episode_returns[i])
self.length_queue.extend(self.episode_lengths[i]) self.length_queue.extend(self.episode_lengths[i])

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
from typing import Any from typing import Any
import numpy as np
from gymnasium.core import ActType, ObsType from gymnasium.core import ActType, ObsType
from gymnasium.vector.vector_env import ArrayType, VectorEnv, VectorWrapper from gymnasium.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
@@ -78,6 +80,7 @@ class DictInfoToList(VectorWrapper):
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]: ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]:
"""Steps through the environment, convert dict info to list.""" """Steps through the environment, convert dict info to list."""
observation, reward, terminated, truncated, infos = self.env.step(actions) observation, reward, terminated, truncated, infos = self.env.step(actions)
assert isinstance(infos, dict)
list_info = self._convert_info_to_list(infos) list_info = self._convert_info_to_list(infos)
return observation, reward, terminated, truncated, list_info return observation, reward, terminated, truncated, list_info
@@ -90,11 +93,12 @@ class DictInfoToList(VectorWrapper):
) -> tuple[ObsType, list[dict[str, Any]]]: ) -> tuple[ObsType, list[dict[str, Any]]]:
"""Resets the environment using kwargs.""" """Resets the environment using kwargs."""
obs, infos = self.env.reset(seed=seed, options=options) obs, infos = self.env.reset(seed=seed, options=options)
assert isinstance(infos, dict)
list_info = self._convert_info_to_list(infos) list_info = self._convert_info_to_list(infos)
return obs, list_info return obs, list_info
def _convert_info_to_list(self, infos: dict) -> list[dict[str, Any]]: def _convert_info_to_list(self, vector_infos: dict) -> list[dict[str, Any]]:
"""Convert the dict info to list. """Convert the dict info to list.
Convert the dict info of the vectorized environment Convert the dict info of the vectorized environment
@@ -102,52 +106,28 @@ class DictInfoToList(VectorWrapper):
has the info of the i-th environment. has the info of the i-th environment.
Args: Args:
infos (dict): info dict coming from the env. vector_infos (dict): info dict coming from the env.
Returns: Returns:
list_info (list): converted info. list_info (list): converted info.
""" """
list_info = [{} for _ in range(self.num_envs)] list_info = [{} for _ in range(self.num_envs)]
list_info = self._process_episode_statistics(infos, list_info)
for k in infos: for key, value in vector_infos.items():
if k.startswith("_"): if key.startswith("_"):
continue continue
for i, has_info in enumerate(infos[f"_{k}"]):
if has_info:
list_info[i][k] = infos[k][i]
return list_info
# todo - I think this function should be more general for any information if isinstance(value, dict):
def _process_episode_statistics(self, infos: dict, list_info: list) -> list[dict]: value_list_info = self._convert_info_to_list(value)
"""Process episode statistics. for env_num, (env_info, has_info) in enumerate(
zip(value_list_info, vector_infos[f"_{key}"])
`RecordEpisodeStatistics` wrapper add extra ):
information to the info. This information are in if has_info:
the form of a dict of dict. This method process these list_info[env_num][key] = env_info
information and add them to the info. else:
`RecordEpisodeStatistics` info contains the keys assert isinstance(value, np.ndarray)
"r", "l", "t" which represents "cumulative reward", for env_num, has_info in enumerate(vector_infos[f"_{key}"]):
"episode length", "elapsed time since instantiation of wrapper". if has_info:
list_info[env_num][key] = value[env_num]
Args:
infos (dict): infos coming from `RecordEpisodeStatistics`.
list_info (list): info of the current vectorized environment.
Returns:
list_info (list): updated info.
"""
episode_statistics = infos.pop("episode", False)
if not episode_statistics:
return list_info
episode_statistics_mask = infos.pop("_episode")
for i, has_info in enumerate(episode_statistics_mask):
if has_info:
list_info[i]["episode"] = {}
list_info[i]["episode"]["r"] = episode_statistics["r"][i]
list_info[i]["episode"]["l"] = episode_statistics["l"][i]
list_info[i]["episode"]["t"] = episode_statistics["t"][i]
return list_info return list_info

View File

@@ -34,9 +34,9 @@ class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructor
>>> for _ in range(100): >>> for _ in range(100):
... obs, *_ = envs.step(envs.action_space.sample()) ... obs, *_ = envs.step(envs.action_space.sample())
>>> np.mean(obs) >>> np.mean(obs)
-0.017698428 0.024251968
>>> np.std(obs) >>> np.std(obs)
0.62041104 0.62259156
>>> envs.close() >>> envs.close()
Example with the normalize reward wrapper: Example with the normalize reward wrapper:
@@ -48,9 +48,9 @@ class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructor
>>> for _ in range(100): >>> for _ in range(100):
... obs, *_ = envs.step(envs.action_space.sample()) ... obs, *_ = envs.step(envs.action_space.sample())
>>> np.mean(obs) >>> np.mean(obs)
-0.28381696 -0.2359734
>>> np.std(obs) >>> np.std(obs)
1.21742 1.1938739
>>> envs.close() >>> envs.close()
""" """
@@ -81,29 +81,15 @@ class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructor
"""Sets the property to freeze/continue the running mean calculation of the observation statistics.""" """Sets the property to freeze/continue the running mean calculation of the observation statistics."""
self._update_running_mean = setting self._update_running_mean = setting
def vector_observation(self, observation: ObsType) -> ObsType: def observation(self, observations: ObsType) -> ObsType:
"""Defines the vector observation normalization function. """Defines the vector observation normalization function.
Args: Args:
observation: A vector observation from the environment observations: A vector observation from the environment
Returns: Returns:
the normalized observation the normalized observation
""" """
return self._normalize_observations(observation)
def single_observation(self, observation: ObsType) -> ObsType:
"""Defines the single observation normalization function.
Args:
observation: A single observation from the environment
Returns:
The normalized observation
"""
return self._normalize_observations(observation[None])
def _normalize_observations(self, observations: ObsType) -> ObsType:
if self._update_running_mean: if self._update_running_mean:
self.obs_rms.update(observations) self.obs_rms.update(observations)
return (observations - self.obs_rms.mean) / np.sqrt( return (observations - self.obs_rms.mean) / np.sqrt(

View File

@@ -37,13 +37,10 @@ class TransformObservation(VectorObservationWrapper):
>>> def scale_and_shift(obs): >>> def scale_and_shift(obs):
... return (obs - 1.0) * 2.0 ... return (obs - 1.0) * 2.0
... ...
>>> def vector_scale_and_shift(obs):
... return (obs - 1.0) * 2.0
...
>>> import gymnasium as gym >>> import gymnasium as gym
>>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync") >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
>>> new_obs_space = Box(low=envs.observation_space.low, high=envs.observation_space.high) >>> new_obs_space = Box(low=envs.observation_space.low, high=envs.observation_space.high)
>>> envs = TransformObservation(envs, single_func=scale_and_shift, vector_func=vector_scale_and_shift) >>> envs = TransformObservation(envs, func=scale_and_shift, observation_space=new_obs_space)
>>> obs, info = envs.reset(seed=123) >>> obs, info = envs.reset(seed=123)
>>> obs >>> obs
array([[-1.9635296, -2.0892358, -2.055928 , -2.0631256], array([[-1.9635296, -2.0892358, -2.055928 , -2.0631256],
@@ -55,16 +52,14 @@ class TransformObservation(VectorObservationWrapper):
def __init__( def __init__(
self, self,
env: VectorEnv, env: VectorEnv,
vector_func: Callable[[ObsType], Any], func: Callable[[ObsType], Any],
single_func: Callable[[ObsType], Any],
observation_space: Space | None = None, observation_space: Space | None = None,
): ):
"""Constructor for the transform observation wrapper. """Constructor for the transform observation wrapper.
Args: Args:
env: The vector environment to wrap env: The vector environment to wrap
vector_func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``. func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
single_func: A function that will transform an individual observation, this function will be used for the final observation from the environment and is returned under ``info`` and not the normal observation.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``. observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
""" """
super().__init__(env) super().__init__(env)
@@ -72,16 +67,11 @@ class TransformObservation(VectorObservationWrapper):
if observation_space is not None: if observation_space is not None:
self.observation_space = observation_space self.observation_space = observation_space
self.vector_func = vector_func self.func = func
self.single_func = single_func
def vector_observation(self, observation: ObsType) -> ObsType: def observation(self, observations: ObsType) -> ObsType:
"""Apply function to the vector observation.""" """Apply function to the vector observation."""
return self.vector_func(observation) return self.func(observations)
def single_observation(self, observation: ObsType) -> ObsType:
"""Apply function to the single observation."""
return self.single_func(observation)
class VectorizeTransformObservation(VectorObservationWrapper): class VectorizeTransformObservation(VectorObservationWrapper):
@@ -158,16 +148,16 @@ class VectorizeTransformObservation(VectorObservationWrapper):
self.same_out = self.observation_space == self.env.observation_space self.same_out = self.observation_space == self.env.observation_space
self.out = create_empty_array(self.single_observation_space, self.num_envs) self.out = create_empty_array(self.single_observation_space, self.num_envs)
def vector_observation(self, observation: ObsType) -> ObsType: def observation(self, observations: ObsType) -> ObsType:
"""Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again.""" """Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again."""
if self.same_out: if self.same_out:
return concatenate( return concatenate(
self.single_observation_space, self.single_observation_space,
tuple( tuple(
self.wrapper.func(obs) self.wrapper.func(obs)
for obs in iterate(self.observation_space, observation) for obs in iterate(self.observation_space, observations)
), ),
observation, observations,
) )
else: else:
return deepcopy( return deepcopy(
@@ -175,16 +165,12 @@ class VectorizeTransformObservation(VectorObservationWrapper):
self.single_observation_space, self.single_observation_space,
tuple( tuple(
self.wrapper.func(obs) self.wrapper.func(obs)
for obs in iterate(self.env.observation_space, observation) for obs in iterate(self.env.observation_space, observations)
), ),
self.out, self.out,
) )
) )
def single_observation(self, observation: ObsType) -> ObsType:
"""Transforms a single observation using the wrapper transformation function."""
return self.wrapper.func(observation)
class FilterObservation(VectorizeTransformObservation): class FilterObservation(VectorizeTransformObservation):
"""Vector wrapper for filtering dict or tuple observation spaces. """Vector wrapper for filtering dict or tuple observation spaces.

View File

@@ -6,8 +6,14 @@ import jax.numpy as jnp # noqa: E402
import jax.random as jrng # noqa: E402 import jax.random as jrng # noqa: E402
import numpy as np # noqa: E402 import numpy as np # noqa: E402
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional # noqa: E402 from gymnasium.envs.phys2d.cartpole import ( # noqa: E402
from gymnasium.envs.phys2d.pendulum import PendulumFunctional # noqa: E402 CartPoleFunctional,
CartPoleJaxVectorEnv,
)
from gymnasium.envs.phys2d.pendulum import ( # noqa: E402
PendulumFunctional,
PendulumJaxVectorEnv,
)
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional]) @pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
@@ -105,3 +111,34 @@ def test_vmap(env_class):
assert obs.dtype == jnp.float32 assert obs.dtype == jnp.float32
state = next_state state = next_state
@pytest.mark.parametrize("env_class", [CartPoleJaxVectorEnv, PendulumJaxVectorEnv])
def test_vectorized(env_class):
env = env_class(num_envs=10)
env.action_space.seed(0)
obs, info = env.reset(seed=0)
assert obs.shape == (10,) + env.single_observation_space.shape
assert isinstance(obs, np.ndarray)
assert isinstance(info, dict)
for t in range(100):
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
assert obs.shape == (10,) + env.single_observation_space.shape
assert isinstance(obs, np.ndarray)
assert reward.shape == (10,)
assert isinstance(reward, np.ndarray)
assert terminated.shape == (10,)
assert isinstance(terminated, np.ndarray)
assert truncated.shape == (10,)
assert isinstance(truncated, np.ndarray)
assert isinstance(info, dict)
# These were removed in the new autoreset order
assert "final_observation" not in info
assert "final_info" not in info
assert "_final_observation" not in info
assert "_final_info" not in info

View File

@@ -33,7 +33,7 @@ def test_create_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_async_vector_env(shared_memory): def test_reset_async_vector_env(shared_memory):
"""Test the reset of an sync vector environment with or without shared memory.""" """Test the reset of async vector environment with or without shared memory."""
env_fns = [make_env("CartPole-v1", i) for i in range(8)] env_fns = [make_env("CartPole-v1", i) for i in range(8)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)

View File

@@ -6,6 +6,7 @@ import numpy as np
import pytest import pytest
from gymnasium.spaces import Discrete from gymnasium.spaces import Discrete
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
from tests.testing_env import GenericTestEnv from tests.testing_env import GenericTestEnv
from tests.vector.testing_utils import make_env from tests.vector.testing_utils import make_env
@@ -29,6 +30,7 @@ def test_vector_env_equal(shared_memory):
async_observations, async_infos = async_env.reset(seed=0) async_observations, async_infos = async_env.reset(seed=0)
sync_observations, sync_infos = sync_env.reset(seed=0) sync_observations, sync_infos = sync_env.reset(seed=0)
assert np.all(async_observations == sync_observations) assert np.all(async_observations == sync_observations)
assert data_equivalence(async_infos, sync_infos)
for _ in range(num_steps): for _ in range(num_steps):
actions = async_env.action_space.sample() actions = async_env.action_space.sample()
@@ -49,16 +51,11 @@ def test_vector_env_equal(shared_memory):
sync_infos, sync_infos,
) = sync_env.step(actions) ) = sync_env.step(actions)
if any(sync_terminations) or any(sync_truncations):
assert "final_observation" in async_infos
assert "_final_observation" in async_infos
assert "final_observation" in sync_infos
assert "_final_observation" in sync_infos
assert np.all(async_observations == sync_observations) assert np.all(async_observations == sync_observations)
assert np.all(async_rewards == sync_rewards) assert np.all(async_rewards == sync_rewards)
assert np.all(async_terminations == sync_terminations) assert np.all(async_terminations == sync_terminations)
assert np.all(async_truncations == sync_truncations) assert np.all(async_truncations == sync_truncations)
assert data_equivalence(async_infos, sync_infos)
async_env.close() async_env.close()
sync_env.close() sync_env.close()
@@ -115,14 +112,13 @@ def test_final_obs_info(vectoriser):
) )
obs, _, termination, _, info = env.step([3]) obs, _, termination, _, info = env.step([3])
assert obs == np.array([0]) and info == {"action": 3, "_action": np.array([True])}
obs, _, terminated, _, info = env.step([4])
assert ( assert (
obs == np.array([0]) obs == np.array([0])
and termination == np.array([True]) and termination == np.array([True])
and info["reset"] == np.array([True]) and info["reset"] == np.array([True])
) )
assert "final_observation" in info and "final_info" in info
assert info["final_observation"] == np.array([0]) and info["final_info"] == {
"action": 3
}
env.close() env.close()

View File

@@ -1,66 +1,164 @@
"""Test the vector environment information.""" """Test the vector environment information."""
from __future__ import annotations
from typing import Any, SupportsFloat
import numpy as np import numpy as np
import pytest import pytest
import gymnasium as gym import gymnasium as gym
from gymnasium.vector.sync_vector_env import SyncVectorEnv from gymnasium.core import ActType, ObsType
from tests.vector.testing_utils import make_env from gymnasium.spaces import Box, Discrete
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv
ENV_ID = "CartPole-v1" def test_vector_add_info():
NUM_ENVS = 3 env = VectorEnv()
ENV_STEPS = 50
SEED = 42 # Test num-envs==1 then expand_dims(sub-env-info) == vector-infos
env.num_envs = 1
sub_env_info = {"a": 0, "b": 0.0, "c": None, "d": np.zeros((2,)), "e": Discrete(1)}
vector_infos = env._add_info({}, sub_env_info, 0)
expected_vector_infos = {
"a": np.array([0]),
"b": np.array([0.0]),
"c": np.array([None], dtype=object),
"d": np.zeros(
(
1,
2,
)
),
"e": np.array([Discrete(1)], dtype=object),
"_a": np.array([True]),
"_b": np.array([True]),
"_c": np.array([True]),
"_d": np.array([True]),
"_e": np.array([True]),
}
assert data_equivalence(vector_infos, expected_vector_infos)
# Thought: num-envs>1 then vector-infos should have the same structure as sub-env-info
env.num_envs = 3
sub_env_infos = [
{"a": 0, "b": 0.0, "c": None, "d": np.zeros((2,)), "e": Discrete(1)},
{"a": 1, "b": 1.0, "c": None, "d": np.zeros((2,)), "e": Discrete(2)},
{"a": 2, "b": 2.0, "c": None, "d": np.zeros((2,)), "e": Discrete(3)},
]
vector_infos = {}
for i, info in enumerate(sub_env_infos):
vector_infos = env._add_info(vector_infos, info, i)
expected_vector_infos = {
"a": np.array([0, 1, 2]),
"b": np.array([0.0, 1.0, 2.0]),
"c": np.array([None, None, None], dtype=object),
"d": np.zeros((3, 2)),
"e": np.array([Discrete(1), Discrete(2), Discrete(3)], dtype=object),
"_a": np.array([True, True, True]),
"_b": np.array([True, True, True]),
"_c": np.array([True, True, True]),
"_d": np.array([True, True, True]),
"_e": np.array([True, True, True]),
}
assert data_equivalence(vector_infos, expected_vector_infos)
# Test different structures of sub-infos
env.num_envs = 3
sub_env_infos = [
{"a": 1, "b": 1.0},
{"c": None, "d": np.zeros((2,))},
{"e": Discrete(3)},
]
vector_infos = {}
for i, info in enumerate(sub_env_infos):
vector_infos = env._add_info(vector_infos, info, i)
expected_vector_infos = {
"a": np.array([1, 0, 0]),
"b": np.array([1.0, 0.0, 0.0]),
"c": np.array([None, None, None], dtype=object),
"d": np.zeros((3, 2)),
"e": np.array([None, None, Discrete(3)], dtype=object),
"_a": np.array([True, False, False]),
"_b": np.array([True, False, False]),
"_c": np.array([False, True, False]),
"_d": np.array([False, True, False]),
"_e": np.array([False, False, True]),
}
assert data_equivalence(vector_infos, expected_vector_infos)
# Test recursive structure
env.num_envs = 3
sub_env_infos = [
{"episode": {"a": 1, "b": 1.0}},
{"episode": {"a": 2, "b": 2.0}, "a": 1},
{"a": 2},
]
vector_infos = {}
for i, info in enumerate(sub_env_infos):
vector_infos = env._add_info(vector_infos, info, i)
expected_vector_infos = {
"episode": {
"a": np.array([1, 2, 0]),
"b": np.array([1.0, 2.0, 0.0]),
"_a": np.array([True, True, False]),
"_b": np.array([True, True, False]),
},
"_episode": np.array([True, True, False]),
"a": np.array([0, 1, 2]),
"_a": np.array([False, True, True]),
}
assert data_equivalence(vector_infos, expected_vector_infos)
@pytest.mark.parametrize("vectorization_mode", ["async", "sync"]) class ReturnInfoEnv(gym.Env):
def test_vector_env_info(vectorization_mode: str): def __init__(self, infos):
"""Test vector environment info for different vectorization modes.""" self.observation_space = Box(0, 1)
env = gym.make_vec( self.action_space = Box(0, 1)
ENV_ID,
num_envs=NUM_ENVS, self.infos = infos
vectorization_mode=vectorization_mode,
def reset(
self,
*,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
return self.observation_space.sample(), self.infos[0]
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
return self.observation_space.sample(), 0, True, False, self.infos[1]
@pytest.mark.parametrize("vectorizer", [AsyncVectorEnv, SyncVectorEnv])
def test_vectorizers(vectorizer):
vec_env = vectorizer(
[
lambda: ReturnInfoEnv([{"a": 1}, {"c": np.array([1, 2])}]),
lambda: ReturnInfoEnv([{"a": 2, "b": 3}, {"c": np.array([3, 4])}]),
]
) )
env.reset(seed=SEED)
for _ in range(ENV_STEPS):
env.action_space.seed(SEED)
action = env.action_space.sample()
_, _, terminateds, truncateds, infos = env.step(action)
if any(terminateds) or any(truncateds):
assert len(infos["final_observation"]) == NUM_ENVS
assert len(infos["_final_observation"]) == NUM_ENVS
assert isinstance(infos["final_observation"], np.ndarray) reset_expected_infos = {
assert isinstance(infos["_final_observation"], np.ndarray) "a": np.array([1, 2]),
"b": np.array([0, 3]),
"_a": np.array([True, True]),
"_b": np.array([False, True]),
}
step_expected_infos = {
"c": np.array([[1, 2], [3, 4]]),
"_c": np.array([True, True]),
}
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)): _, reset_info = vec_env.reset()
if terminated or truncated: assert data_equivalence(reset_info, reset_expected_infos)
assert infos["_final_observation"][i] _, _, _, _, step_info = vec_env.step(vec_env.action_space.sample())
else: assert data_equivalence(step_info, step_expected_infos)
assert not infos["_final_observation"][i]
assert infos["final_observation"][i] is None
env.close()
@pytest.mark.parametrize("concurrent_ends", [1, 2, 3])
def test_vector_env_info_concurrent_termination(concurrent_ends):
"""Test the vector environment information works with concurrent termination."""
# envs that need to terminate together will have the same action
actions = [0] * concurrent_ends + [1] * (NUM_ENVS - concurrent_ends)
envs = [make_env(ENV_ID, SEED) for _ in range(NUM_ENVS)]
envs = SyncVectorEnv(envs)
for _ in range(ENV_STEPS):
_, _, terminateds, truncateds, infos = envs.step(actions)
if any(terminateds) or any(truncateds):
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
if i < concurrent_ends:
assert terminated or truncated
assert infos["_final_observation"][i]
else:
assert not infos["_final_observation"][i]
assert infos["final_observation"][i] is None
return
envs.close()

View File

@@ -45,19 +45,15 @@ def test_autoreset_wrapper_autoreset():
assert info == {"count": 2} assert info == {"count": 2}
obs, reward, terminated, truncated, info = env.step(action) obs, reward, terminated, truncated, info = env.step(action)
assert obs == np.array([0]) assert obs == np.array([3])
assert (terminated or truncated) is True assert (terminated or truncated) is True
assert reward == 1 assert reward == 1
assert info == { assert info == {"count": 3}
"count": 0,
"final_observation": np.array([3]),
"final_info": {"count": 3},
}
obs, reward, terminated, truncated, info = env.step(action) obs, reward, terminated, truncated, info = env.step(action)
assert obs == np.array([1]) assert obs == np.array([0])
assert reward == 0 assert reward == 0
assert (terminated or truncated) is False assert (terminated or truncated) is False
assert info == {"count": 1} assert info == {"count": 0}
env.close() env.close()

View File

@@ -1,21 +1,22 @@
"""Test suite for DictInfoTolist wrapper.""" """Test suite for DictInfoTolist wrapper."""
from __future__ import annotations
from typing import Any
import numpy as np import numpy as np
import pytest import pytest
import gymnasium as gym import gymnasium as gym
from gymnasium.wrappers.vector import DictInfoToList, RecordEpisodeStatistics from gymnasium.core import ObsType
from gymnasium.spaces import Discrete
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.vector import VectorEnv
from gymnasium.wrappers.vector import DictInfoToList
ENV_ID = "CartPole-v1" def test_usage_in_vector_env(env_id: str = "CartPole-v1", num_envs: int = 3):
NUM_ENVS = 3 env = gym.make(env_id, disable_env_checker=True)
ENV_STEPS = 50 vector_env = gym.make_vec(env_id, num_envs=num_envs)
SEED = 42
def test_usage_in_vector_env():
env = gym.make(ENV_ID, disable_env_checker=True)
vector_env = gym.make_vec(ENV_ID, num_envs=NUM_ENVS, vectorization_mode="sync")
DictInfoToList(vector_env) DictInfoToList(vector_env)
@@ -23,40 +24,140 @@ def test_usage_in_vector_env():
DictInfoToList(env) DictInfoToList(env)
def test_info_to_list(): class ResetOptionAsInfo(VectorEnv):
env_to_wrap = gym.make_vec(ENV_ID, num_envs=NUM_ENVS, vectorization_mode="sync") """Minimal implementation to test the conversion of vector dict info to list info."""
wrapped_env = DictInfoToList(env_to_wrap)
wrapped_env.action_space.seed(SEED)
_, info = wrapped_env.reset(seed=SEED)
assert isinstance(info, list)
assert len(info) == NUM_ENVS
for _ in range(ENV_STEPS): def reset(
action = wrapped_env.action_space.sample() self,
_, _, terminateds, truncateds, list_info = wrapped_env.step(action) *,
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)): seed: int | list[int] | None = None,
if terminated or truncated: options: dict[str, Any] | None = None, # options are passed are the info output
assert "final_observation" in list_info[i] ) -> tuple[ObsType, dict[str, Any]]:
else: return None, options
assert "final_observation" not in list_info[i]
def test_info_to_list_statistics(): def test_update_info():
env_to_wrap = gym.make_vec(ENV_ID, num_envs=NUM_ENVS, vectorization_mode="sync") env = DictInfoToList(ResetOptionAsInfo())
wrapped_env = DictInfoToList(RecordEpisodeStatistics(env_to_wrap))
_, info = wrapped_env.reset(seed=SEED)
wrapped_env.action_space.seed(SEED)
assert isinstance(info, list)
assert len(info) == NUM_ENVS
for _ in range(ENV_STEPS): # Test num-envs==1 then expand_dims(sub-env-info) == vector-infos
action = wrapped_env.action_space.sample() env.unwrapped.num_envs = 1
_, _, terminateds, truncateds, list_info = wrapped_env.step(action)
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)): vector_infos = {
if terminated or truncated: "a": np.array([0]),
assert "episode" in list_info[i] "b": np.array([0.0]),
for stats in ["r", "l", "t"]: "c": np.array([None], dtype=object),
assert stats in list_info[i]["episode"] "d": np.zeros(
assert np.isscalar(list_info[i]["episode"][stats]) (
else: 1,
assert "episode" not in list_info[i] 2,
)
),
"e": np.array([Discrete(1)], dtype=object),
"_a": np.array([True]),
"_b": np.array([True]),
"_c": np.array([True]),
"_d": np.array([True]),
"_e": np.array([True]),
}
_, list_info = env.reset(options=vector_infos)
expected_list_info = [
{
"a": np.int64(0),
"b": np.float64(0.0),
"c": None,
"d": np.zeros((2,)),
"e": Discrete(1),
}
]
assert data_equivalence(list_info, expected_list_info)
# Thought: num-envs>1 then vector-infos should have the same structure as sub-env-info
env.unwrapped.num_envs = 3
vector_infos = {
"a": np.array([0, 1, 2]),
"b": np.array([0.0, 1.0, 2.0]),
"c": np.array([None, None, None], dtype=object),
"d": np.zeros((3, 2)),
"e": np.array([Discrete(1), Discrete(2), Discrete(3)], dtype=object),
"_a": np.array([True, True, True]),
"_b": np.array([True, True, True]),
"_c": np.array([True, True, True]),
"_d": np.array([True, True, True]),
"_e": np.array([True, True, True]),
}
_, list_info = env.reset(options=vector_infos)
expected_list_info = [
{
"a": np.int64(0),
"b": np.float64(0.0),
"c": None,
"d": np.zeros((2,)),
"e": Discrete(1),
},
{
"a": np.int64(1),
"b": np.float64(1.0),
"c": None,
"d": np.zeros((2,)),
"e": Discrete(2),
},
{
"a": np.int64(2),
"b": np.float64(2.0),
"c": None,
"d": np.zeros((2,)),
"e": Discrete(3),
},
]
assert list_info[0].keys() == expected_list_info[0].keys()
for key in list_info[0].keys():
assert data_equivalence(list_info[0][key], expected_list_info[0][key])
assert data_equivalence(list_info, expected_list_info)
# Test different structures of sub-infos
env.unwrapped.num_envs = 3
vector_infos = {
"a": np.array([1, 0, 0]),
"_a": np.array([True, False, False]),
"b": np.array([1.0, 0.0, 0.0]),
"_b": np.array([True, False, False]),
"c": np.array([None, None, None], dtype=object),
"_c": np.array([False, True, False]),
"_d": np.array([False, True, False]),
"d": np.zeros((3, 2)),
"e": np.array([None, None, Discrete(3)], dtype=object),
"_e": np.array([False, False, True]),
}
_, list_info = env.reset(options=vector_infos)
expected_list_info = [
{"a": np.int64(1), "b": np.float64(1.0)},
{"c": None, "d": np.zeros((2,))},
{"e": Discrete(3)},
]
assert data_equivalence(list_info, expected_list_info)
# Test recursive structure
env.unwrapped.num_envs = 3
vector_infos = {
"episode": {
"a": np.array([1, 2, 0]),
"b": np.array([1.0, 2.0, 0.0]),
"_a": np.array([True, True, False]),
"_b": np.array([True, True, False]),
},
"_episode": np.array([True, True, False]),
"a": np.array([0, 1, 2]),
"_a": np.array([False, True, True]),
}
_, list_info = env.reset(options=vector_infos)
expected_list_info = [
{"episode": {"a": np.int64(1), "b": np.float64(1.0)}},
{"episode": {"a": np.int64(2), "b": np.float64(2.0)}, "a": np.int64(1)},
{"a": np.int64(2)},
]
assert data_equivalence(list_info, expected_list_info)

View File

@@ -0,0 +1,71 @@
import pytest
import gymnasium as gym
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.vector import VectorEnv
@pytest.mark.parametrize("num_envs", (1, 3))
def test_record_episode_statistics(num_envs, env_id="CartPole-v1", num_steps=100):
wrapper_vector_env: VectorEnv = gym.wrappers.vector.RecordEpisodeStatistics(
gym.make_vec(id=env_id, num_envs=num_envs, vectorization_mode="sync"),
)
vector_wrapper_env = gym.make_vec(
id=env_id,
num_envs=num_envs,
vectorization_mode="sync",
wrappers=(gym.wrappers.RecordEpisodeStatistics,),
)
assert wrapper_vector_env.action_space == vector_wrapper_env.action_space
assert wrapper_vector_env.observation_space == vector_wrapper_env.observation_space
assert (
wrapper_vector_env.single_action_space == vector_wrapper_env.single_action_space
)
assert (
wrapper_vector_env.single_observation_space
== vector_wrapper_env.single_observation_space
)
assert wrapper_vector_env.num_envs == vector_wrapper_env.num_envs
wrapper_vector_obs, wrapper_vector_info = wrapper_vector_env.reset(seed=123)
vector_wrapper_obs, vector_wrapper_info = vector_wrapper_env.reset(seed=123)
assert data_equivalence(wrapper_vector_obs, vector_wrapper_obs)
assert data_equivalence(wrapper_vector_info, vector_wrapper_info)
for _ in range(num_steps):
action = wrapper_vector_env.action_space.sample()
(
wrapper_vector_obs,
wrapper_vector_reward,
wrapper_vector_terminated,
wrapper_vector_truncated,
wrapper_vector_info,
) = wrapper_vector_env.step(action)
(
vector_wrapper_obs,
vector_wrapper_reward,
vector_wrapper_terminated,
vector_wrapper_truncated,
vector_wrapper_info,
) = vector_wrapper_env.step(action)
data_equivalence(wrapper_vector_obs, vector_wrapper_obs)
data_equivalence(wrapper_vector_reward, vector_wrapper_reward)
data_equivalence(wrapper_vector_terminated, vector_wrapper_terminated)
data_equivalence(wrapper_vector_truncated, vector_wrapper_truncated)
if "episode" in wrapper_vector_info:
assert "episode" in vector_wrapper_info
wrapper_vector_time = wrapper_vector_info["episode"].pop("t")
vector_wrapper_time = vector_wrapper_info["episode"].pop("t")
assert wrapper_vector_time.shape == vector_wrapper_time.shape
assert wrapper_vector_time.dtype == vector_wrapper_time.dtype
data_equivalence(wrapper_vector_info, vector_wrapper_info)
wrapper_vector_env.close()
vector_wrapper_env.close()

View File

@@ -42,21 +42,21 @@ def custom_environments():
("CustomDictEnv-v0", "FilterObservation", {"filter_keys": ["a"]}), ("CustomDictEnv-v0", "FilterObservation", {"filter_keys": ["a"]}),
("CartPole-v1", "FlattenObservation", {}), ("CartPole-v1", "FlattenObservation", {}),
("CarRacing-v2", "GrayscaleObservation", {}), ("CarRacing-v2", "GrayscaleObservation", {}),
# ("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}), ("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}),
("CarRacing-v2", "ReshapeObservation", {"shape": (96, 48, 6)}), ("CarRacing-v2", "ReshapeObservation", {"shape": (96, 48, 6)}),
("CartPole-v1", "RescaleObservation", {"min_obs": 0, "max_obs": 1}), ("CartPole-v1", "RescaleObservation", {"min_obs": 0, "max_obs": 1}),
("CartPole-v1", "DtypeObservation", {"dtype": np.int32}), ("CartPole-v1", "DtypeObservation", {"dtype": np.int32}),
# ("CartPole-v1", "RenderObservation", {}), # ("CartPole-v1", "RenderObservation", {}), # not implemented
# ("CartPole-v1", "TimeAwareObservation", {}), # ("CartPole-v1", "TimeAwareObservation", {}), # not implemented
# ("CartPole-v1", "FrameStackObservation", {}), # ("CartPole-v1", "FrameStackObservation", {}), # not implemented
# ("CartPole-v1", "DelayObservation", {}), # ("CartPole-v1", "DelayObservation", {}), # not implemented
("MountainCarContinuous-v0", "ClipAction", {}), ("MountainCarContinuous-v0", "ClipAction", {}),
( (
"MountainCarContinuous-v0", "MountainCarContinuous-v0",
"RescaleAction", "RescaleAction",
{"min_action": 1, "max_action": 2}, {"min_action": 1, "max_action": 2},
), ),
("CartPole-v1", "ClipReward", {"min_reward": 0.25, "max_reward": 0.75}), ("CartPole-v1", "ClipReward", {"min_reward": -0.25, "max_reward": 0.75}),
), ),
) )
def test_vector_wrapper_equivalence( def test_vector_wrapper_equivalence(