Change autoreset order (#808)

Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
Ariel Kwiatkowski
2023-12-03 19:50:18 +01:00
committed by GitHub
parent 967bbf5823
commit e9c66e4225
18 changed files with 591 additions and 364 deletions

View File

@@ -151,6 +151,8 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)
self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_)
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
if self.render_mode == "rgb_array":
@@ -214,10 +216,9 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
info = self.func_env.transition_info(self.state, action, next_state)
done = jnp.logical_or(terminated, truncated)
if jnp.any(done):
final_obs = self.func_env.observation(next_state)
to_reset = jnp.where(done)[0]
if jnp.any(self.autoreset_envs):
to_reset = jnp.where(self.autoreset_envs)[0]
reset_count = to_reset.shape[0]
rng, self.rng = jrng.split(self.rng)
@@ -228,34 +229,16 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
next_state = self.state.at[to_reset].set(new_initials)
self.steps = self.steps.at[to_reset].set(0)
# Get the final observations and infos
info["final_observation"] = np.array([None for _ in range(self.num_envs)])
info["final_info"] = np.array([None for _ in range(self.num_envs)])
info["_final_observation"] = np.array([False for _ in range(self.num_envs)])
info["_final_info"] = np.array([False for _ in range(self.num_envs)])
# TODO: this can maybe be optimized, but right now I don't know how
for i in to_reset:
info["final_observation"][i] = final_obs[i]
info["final_info"][i] = {
k: v[i]
for k, v in info.items()
if k
not in {
"final_observation",
"final_info",
"_final_observation",
"_final_info",
}
}
info["_final_observation"][i] = True
info["_final_info"][i] = True
self.autoreset_envs = done
observation = self.func_env.observation(next_state)
observation = jax_to_numpy(observation)
reward = jax_to_numpy(reward)
terminated = jax_to_numpy(terminated)
truncated = jax_to_numpy(truncated)
self.state = next_state
return observation, reward, terminated, truncated, info

View File

@@ -639,6 +639,7 @@ def _async_worker(
env = env_fn()
observation_space = env.observation_space
action_space = env.action_space
autoreset = False
parent_pipe.close()
@@ -653,8 +654,13 @@ def _async_worker(
observation_space, index, observation, shared_memory
)
observation = None
autoreset = False
pipe.send(((observation, info), True))
elif command == "step":
if autoreset:
observation, info = env.reset()
reward, terminated, truncated = 0, False, False
else:
(
observation,
reward,
@@ -662,11 +668,7 @@ def _async_worker(
truncated,
info,
) = env.step(data)
if terminated or truncated:
old_observation, old_info = observation, info
observation, info = env.reset()
info["final_observation"] = old_observation
info["final_info"] = old_info
autoreset = terminated or truncated
if shared_memory:
write_to_shared_memory(

View File

@@ -98,6 +98,8 @@ class SyncVectorEnv(VectorEnv):
self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)
self._autoreset_envs = np.zeros((self.num_envs,), dtype=np.bool_)
def reset(
self,
*,
@@ -150,22 +152,21 @@ class SyncVectorEnv(VectorEnv):
actions = iterate(self.action_space, actions)
observations, infos = [], {}
for i, (env, action) in enumerate(zip(self.envs, actions)):
for i, action in enumerate(actions):
if self._autoreset_envs[i]:
env_obs, env_info = self.envs[i].reset()
self._rewards[i] = 0.0
self._terminations[i] = False
self._truncations[i] = False
else:
(
env_obs,
self._rewards[i],
self._terminations[i],
self._truncations[i],
env_info,
) = env.step(action)
# If sub-environments terminates or truncates then save the obs and info to the batched info
if self._terminations[i] or self._truncations[i]:
old_observation, old_info = env_obs, env_info
env_obs, env_info = env.reset()
env_info["final_observation"] = old_observation
env_info["final_info"] = old_info
) = self.envs[i].step(action)
observations.append(env_obs)
infos = self._add_info(infos, env_info, i)
@@ -174,6 +175,7 @@ class SyncVectorEnv(VectorEnv):
self._observations = concatenate(
self.single_observation_space, observations, self._observations
)
self._autoreset_envs = np.logical_or(self._terminations, self._truncations)
return (
deepcopy(self._observations) if self.copy else self._observations,

View File

@@ -231,7 +231,7 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
return self
def _add_info(
self, infos: dict[str, Any], info: dict[str, Any], env_num: int
self, vector_infos: dict[str, Any], env_info: dict[str, Any], env_num: int
) -> dict[str, Any]:
"""Add env info to the info dictionary of the vectorized environment.
@@ -241,48 +241,51 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
whether or not the i-indexed environment has this `info`.
Args:
infos (dict): the infos of the vectorized environment
info (dict): the info coming from the single environment
vector_infos (dict): the infos of the vectorized environment
env_info (dict): the info coming from the single environment
env_num (int): the index of the single environment
Returns:
infos (dict): the (updated) infos of the vectorized environment
"""
for k in info.keys():
if k not in infos:
info_array, array_mask = self._init_info_arrays(type(info[k]))
for key, value in env_info.items():
# If value is a dictionary, then we apply the `_add_info` recursively.
if isinstance(value, dict):
array = self._add_info(vector_infos.get(key, {}), value, env_num)
# Otherwise, we are a base case to group the data
else:
info_array, array_mask = infos[k], infos[f"_{k}"]
info_array[env_num], array_mask[env_num] = info[k], True
infos[k], infos[f"_{k}"] = info_array, array_mask
return infos
def _init_info_arrays(self, dtype: type) -> tuple[np.ndarray, np.ndarray]:
"""Initialize the info array.
Initialize the info array. If the dtype is numeric
the info array will have the same dtype, otherwise
will be an array of `None`. Also, a boolean array
of the same length is returned. It will be used for
assessing which environment has info data.
Args:
dtype (type): data type of the info coming from the env.
Returns:
array (np.ndarray): the initialized info array.
array_mask (np.ndarray): the initialized boolean array.
"""
if dtype in [int, float, bool] or issubclass(dtype, np.number):
array = np.zeros(self.num_envs, dtype=dtype)
# If the key doesn't exist in the vector infos, then we can create an array of that batch type
if key not in vector_infos:
if type(value) in [int, float, bool] or issubclass(
type(value), np.number
):
array = np.zeros(self.num_envs, dtype=type(value))
elif isinstance(value, np.ndarray):
# We assume that all instances of the np.array info are of the same shape
array = np.zeros(
(self.num_envs, *value.shape), dtype=value.dtype
)
else:
array = np.zeros(self.num_envs, dtype=object)
array[:] = None
array_mask = np.zeros(self.num_envs, dtype=bool)
return array, array_mask
# For unknown objects, we use a Numpy object array
array = np.full(self.num_envs, fill_value=None, dtype=object)
# Otherwise, just use the array that already exists
else:
array = vector_infos[key]
# Assign the data in the `env_num` position
# We only want to run this for the base-case data (not recursive data forcing the ugly function structure)
array[env_num] = value
# Get the array mask and if it doesn't already exist then create a zero bool array
array_mask = vector_infos.get(
f"_{key}", np.zeros(self.num_envs, dtype=np.bool_)
)
array_mask[env_num] = True
# Update the vector info with the updated data and mask information
vector_infos[key], vector_infos[f"_{key}"] = array, array_mask
return vector_infos
def __del__(self):
"""Closes the vector environment."""
@@ -441,23 +444,23 @@ class VectorObservationWrapper(VectorWrapper):
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
obs, info = self.env.reset(seed=seed, options=options)
return self.vector_observation(obs), info
observations, infos = self.env.reset(seed=seed, options=options)
return self.observation(observations), infos
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
observation, reward, termination, truncation, info = self.env.step(actions)
observations, rewards, terminations, truncations, infos = self.env.step(actions)
return (
self.vector_observation(observation),
reward,
termination,
truncation,
self.update_final_obs(info),
self.observation(observations),
rewards,
terminations,
truncations,
infos,
)
def vector_observation(self, observation: ObsType) -> ObsType:
def observation(self, observation: ObsType) -> ObsType:
"""Defines the vector observation transformation.
Args:
@@ -468,25 +471,6 @@ class VectorObservationWrapper(VectorWrapper):
"""
raise NotImplementedError
def single_observation(self, observation: ObsType) -> ObsType:
"""Defines the single observation transformation.
Args:
observation: A single observation from the environment
Returns:
The transformed observation
"""
raise NotImplementedError
def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
"""Updates the `final_obs` in the info using `single_observation`."""
if "final_observation" in info:
for i, obs in enumerate(info["final_observation"]):
if obs is not None:
info["final_observation"][i] = self.single_observation(obs)
return info
class VectorActionWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the actions.
@@ -522,14 +506,14 @@ class VectorRewardWrapper(VectorWrapper):
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Steps through the environment returning a reward modified by :meth:`reward`."""
observation, reward, termination, truncation, info = self.env.step(actions)
return observation, self.rewards(reward), termination, truncation, info
observations, rewards, terminations, truncations, infos = self.env.step(actions)
return observations, self.rewards(rewards), terminations, truncations, infos
def rewards(self, reward: ArrayType) -> ArrayType:
def rewards(self, rewards: ArrayType) -> ArrayType:
"""Transform the reward before returning it.
Args:
reward (array): the reward to transform
rewards (array): the reward to transform
Returns:
array: the transformed reward

View File

@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any, SupportsFloat
import gymnasium as gym
from gymnasium import logger
from gymnasium.core import ActType, ObsType, RenderFrame
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperObsType
from gymnasium.error import ResetNeeded
from gymnasium.utils.passive_env_checker import (
check_action_space,
@@ -196,6 +196,15 @@ class Autoreset(
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)
self.autoreset = False
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment and sets autoreset to False preventing."""
self.autoreset = False
return super().reset(seed=seed, options=options)
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
@@ -207,24 +216,13 @@ class Autoreset(
Returns:
The autoreset environment :meth:`step`
"""
if self.autoreset:
obs, info = self.env.reset()
reward, terminated, truncated = 0.0, False, False
else:
obs, reward, terminated, truncated, info = self.env.step(action)
if terminated or truncated:
new_obs, new_info = self.env.reset()
assert (
"final_observation" not in new_info
), f'new info dict already contains "final_observation", info keys: {new_info.keys()}'
assert (
"final_info" not in new_info
), f'new info dict already contains "final_observation", info keys: {new_info.keys()}'
new_info["final_observation"] = obs
new_info["final_info"] = info
obs = new_obs
info = new_info
self.autoreset = terminated or truncated
return obs, reward, terminated, truncated, info
@@ -470,14 +468,14 @@ class RecordEpisodeStatistics(
def __init__(
self,
env: gym.Env[ObsType, ActType],
buffer_length: int | None = 100,
buffer_length: int = 100,
stats_key: str = "episode",
):
"""This wrapper will keep track of cumulative rewards and episode lengths.
Args:
env (Env): The environment to apply the wrapper
buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
buffer_length: The size of the buffers :attr:`return_queue`, :attr:`length_queue` and :attr:`time_queue`
stats_key: The info key for the episode statistics
"""
gym.utils.RecordConstructorArgs.__init__(self)
@@ -520,6 +518,7 @@ class RecordEpisodeStatistics(
self.length_queue.append(self.episode_lengths)
self.episode_count += 1
self.episode_start_time = time.perf_counter()
return obs, reward, terminated, truncated, info

View File

@@ -444,7 +444,7 @@ class NormalizeObservation(
Change logs:
* v0.21.0 - Initially add
* v1.0.0 - Add `update_running_mean` attribute to allow disabling of updating the running mean / standard
* v1.0.0 - Add `update_running_mean` attribute to allow disabling of updating the running mean / standard, particularly useful for evaluation time.
"""
def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):

View File

@@ -62,14 +62,21 @@ class RecordEpisodeStatistics(VectorWrapper):
None, None], dtype=object)}
"""
def __init__(self, env: VectorEnv, deque_size: int = 100):
def __init__(
self,
env: VectorEnv,
deque_size: int = 100,
stats_key: str = "episode",
):
"""This wrapper will keep track of cumulative rewards and episode lengths.
Args:
env (Env): The environment to apply the wrapper
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
stats_key: The info key to save the data
"""
super().__init__(env)
self._stats_key = stats_key
self.episode_count = 0
@@ -77,6 +84,7 @@ class RecordEpisodeStatistics(VectorWrapper):
self.episode_returns: np.ndarray = np.zeros(())
self.episode_lengths: np.ndarray = np.zeros(())
self.time_queue = deque(maxlen=deque_size)
self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size)
@@ -88,11 +96,9 @@ class RecordEpisodeStatistics(VectorWrapper):
"""Resets the environment using kwargs and resets the episode returns and lengths."""
obs, info = super().reset(seed=seed, options=options)
self.episode_start_times = np.full(
self.num_envs, time.perf_counter(), dtype=np.float32
)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
self.episode_start_times = np.full(self.num_envs, time.perf_counter())
self.episode_returns = np.zeros(self.num_envs)
self.episode_lengths = np.zeros(self.num_envs)
return obs, info
@@ -110,7 +116,7 @@ class RecordEpisodeStatistics(VectorWrapper):
assert isinstance(
infos, dict
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
), f"`vector.RecordEpisodeStatistics` requires `info` type to be `dict`, its actual type is {type(infos)}. This may be due to usage of other wrappers in the wrong order."
self.episode_returns += rewards
self.episode_lengths += 1
@@ -119,25 +125,25 @@ class RecordEpisodeStatistics(VectorWrapper):
num_dones = np.sum(dones)
if num_dones:
if "episode" in infos or "_episode" in infos:
if self._stats_key in infos or f"_{self._stats_key}" in infos:
raise ValueError(
"Attempted to add episode stats when they already exist"
f"Attempted to add episode stats when they already exist, info keys: {list(infos.keys())}"
)
else:
infos["episode"] = {
episode_time_length = np.round(
time.perf_counter() - self.episode_start_times, 6
)
infos[self._stats_key] = {
"r": np.where(dones, self.episode_returns, 0.0),
"l": np.where(dones, self.episode_lengths, 0),
"t": np.where(
dones,
np.round(time.perf_counter() - self.episode_start_times, 6),
0.0,
),
"t": np.where(dones, episode_time_length, 0.0),
}
infos["_episode"] = dones
infos[f"_{self._stats_key}"] = dones
self.episode_count += num_dones
for i in np.where(dones):
self.time_queue.extend(episode_time_length[i])
self.return_queue.extend(self.episode_returns[i])
self.length_queue.extend(self.episode_lengths[i])

View File

@@ -3,6 +3,8 @@ from __future__ import annotations
from typing import Any
import numpy as np
from gymnasium.core import ActType, ObsType
from gymnasium.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
@@ -78,6 +80,7 @@ class DictInfoToList(VectorWrapper):
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]:
"""Steps through the environment, convert dict info to list."""
observation, reward, terminated, truncated, infos = self.env.step(actions)
assert isinstance(infos, dict)
list_info = self._convert_info_to_list(infos)
return observation, reward, terminated, truncated, list_info
@@ -90,11 +93,12 @@ class DictInfoToList(VectorWrapper):
) -> tuple[ObsType, list[dict[str, Any]]]:
"""Resets the environment using kwargs."""
obs, infos = self.env.reset(seed=seed, options=options)
assert isinstance(infos, dict)
list_info = self._convert_info_to_list(infos)
return obs, list_info
def _convert_info_to_list(self, infos: dict) -> list[dict[str, Any]]:
def _convert_info_to_list(self, vector_infos: dict) -> list[dict[str, Any]]:
"""Convert the dict info to list.
Convert the dict info of the vectorized environment
@@ -102,52 +106,28 @@ class DictInfoToList(VectorWrapper):
has the info of the i-th environment.
Args:
infos (dict): info dict coming from the env.
vector_infos (dict): info dict coming from the env.
Returns:
list_info (list): converted info.
"""
list_info = [{} for _ in range(self.num_envs)]
list_info = self._process_episode_statistics(infos, list_info)
for k in infos:
if k.startswith("_"):
for key, value in vector_infos.items():
if key.startswith("_"):
continue
for i, has_info in enumerate(infos[f"_{k}"]):
if isinstance(value, dict):
value_list_info = self._convert_info_to_list(value)
for env_num, (env_info, has_info) in enumerate(
zip(value_list_info, vector_infos[f"_{key}"])
):
if has_info:
list_info[i][k] = infos[k][i]
return list_info
# todo - I think this function should be more general for any information
def _process_episode_statistics(self, infos: dict, list_info: list) -> list[dict]:
"""Process episode statistics.
`RecordEpisodeStatistics` wrapper add extra
information to the info. This information are in
the form of a dict of dict. This method process these
information and add them to the info.
`RecordEpisodeStatistics` info contains the keys
"r", "l", "t" which represents "cumulative reward",
"episode length", "elapsed time since instantiation of wrapper".
Args:
infos (dict): infos coming from `RecordEpisodeStatistics`.
list_info (list): info of the current vectorized environment.
Returns:
list_info (list): updated info.
"""
episode_statistics = infos.pop("episode", False)
if not episode_statistics:
return list_info
episode_statistics_mask = infos.pop("_episode")
for i, has_info in enumerate(episode_statistics_mask):
list_info[env_num][key] = env_info
else:
assert isinstance(value, np.ndarray)
for env_num, has_info in enumerate(vector_infos[f"_{key}"]):
if has_info:
list_info[i]["episode"] = {}
list_info[i]["episode"]["r"] = episode_statistics["r"][i]
list_info[i]["episode"]["l"] = episode_statistics["l"][i]
list_info[i]["episode"]["t"] = episode_statistics["t"][i]
list_info[env_num][key] = value[env_num]
return list_info

View File

@@ -34,9 +34,9 @@ class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructor
>>> for _ in range(100):
... obs, *_ = envs.step(envs.action_space.sample())
>>> np.mean(obs)
-0.017698428
0.024251968
>>> np.std(obs)
0.62041104
0.62259156
>>> envs.close()
Example with the normalize reward wrapper:
@@ -48,9 +48,9 @@ class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructor
>>> for _ in range(100):
... obs, *_ = envs.step(envs.action_space.sample())
>>> np.mean(obs)
-0.28381696
-0.2359734
>>> np.std(obs)
1.21742
1.1938739
>>> envs.close()
"""
@@ -81,29 +81,15 @@ class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructor
"""Sets the property to freeze/continue the running mean calculation of the observation statistics."""
self._update_running_mean = setting
def vector_observation(self, observation: ObsType) -> ObsType:
def observation(self, observations: ObsType) -> ObsType:
"""Defines the vector observation normalization function.
Args:
observation: A vector observation from the environment
observations: A vector observation from the environment
Returns:
the normalized observation
"""
return self._normalize_observations(observation)
def single_observation(self, observation: ObsType) -> ObsType:
"""Defines the single observation normalization function.
Args:
observation: A single observation from the environment
Returns:
The normalized observation
"""
return self._normalize_observations(observation[None])
def _normalize_observations(self, observations: ObsType) -> ObsType:
if self._update_running_mean:
self.obs_rms.update(observations)
return (observations - self.obs_rms.mean) / np.sqrt(

View File

@@ -37,13 +37,10 @@ class TransformObservation(VectorObservationWrapper):
>>> def scale_and_shift(obs):
... return (obs - 1.0) * 2.0
...
>>> def vector_scale_and_shift(obs):
... return (obs - 1.0) * 2.0
...
>>> import gymnasium as gym
>>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
>>> new_obs_space = Box(low=envs.observation_space.low, high=envs.observation_space.high)
>>> envs = TransformObservation(envs, single_func=scale_and_shift, vector_func=vector_scale_and_shift)
>>> envs = TransformObservation(envs, func=scale_and_shift, observation_space=new_obs_space)
>>> obs, info = envs.reset(seed=123)
>>> obs
array([[-1.9635296, -2.0892358, -2.055928 , -2.0631256],
@@ -55,16 +52,14 @@ class TransformObservation(VectorObservationWrapper):
def __init__(
self,
env: VectorEnv,
vector_func: Callable[[ObsType], Any],
single_func: Callable[[ObsType], Any],
func: Callable[[ObsType], Any],
observation_space: Space | None = None,
):
"""Constructor for the transform observation wrapper.
Args:
env: The vector environment to wrap
vector_func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
single_func: A function that will transform an individual observation, this function will be used for the final observation from the environment and is returned under ``info`` and not the normal observation.
func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
"""
super().__init__(env)
@@ -72,16 +67,11 @@ class TransformObservation(VectorObservationWrapper):
if observation_space is not None:
self.observation_space = observation_space
self.vector_func = vector_func
self.single_func = single_func
self.func = func
def vector_observation(self, observation: ObsType) -> ObsType:
def observation(self, observations: ObsType) -> ObsType:
"""Apply function to the vector observation."""
return self.vector_func(observation)
def single_observation(self, observation: ObsType) -> ObsType:
"""Apply function to the single observation."""
return self.single_func(observation)
return self.func(observations)
class VectorizeTransformObservation(VectorObservationWrapper):
@@ -158,16 +148,16 @@ class VectorizeTransformObservation(VectorObservationWrapper):
self.same_out = self.observation_space == self.env.observation_space
self.out = create_empty_array(self.single_observation_space, self.num_envs)
def vector_observation(self, observation: ObsType) -> ObsType:
def observation(self, observations: ObsType) -> ObsType:
"""Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again."""
if self.same_out:
return concatenate(
self.single_observation_space,
tuple(
self.wrapper.func(obs)
for obs in iterate(self.observation_space, observation)
for obs in iterate(self.observation_space, observations)
),
observation,
observations,
)
else:
return deepcopy(
@@ -175,16 +165,12 @@ class VectorizeTransformObservation(VectorObservationWrapper):
self.single_observation_space,
tuple(
self.wrapper.func(obs)
for obs in iterate(self.env.observation_space, observation)
for obs in iterate(self.env.observation_space, observations)
),
self.out,
)
)
def single_observation(self, observation: ObsType) -> ObsType:
"""Transforms a single observation using the wrapper transformation function."""
return self.wrapper.func(observation)
class FilterObservation(VectorizeTransformObservation):
"""Vector wrapper for filtering dict or tuple observation spaces.

View File

@@ -6,8 +6,14 @@ import jax.numpy as jnp # noqa: E402
import jax.random as jrng # noqa: E402
import numpy as np # noqa: E402
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional # noqa: E402
from gymnasium.envs.phys2d.pendulum import PendulumFunctional # noqa: E402
from gymnasium.envs.phys2d.cartpole import ( # noqa: E402
CartPoleFunctional,
CartPoleJaxVectorEnv,
)
from gymnasium.envs.phys2d.pendulum import ( # noqa: E402
PendulumFunctional,
PendulumJaxVectorEnv,
)
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
@@ -105,3 +111,34 @@ def test_vmap(env_class):
assert obs.dtype == jnp.float32
state = next_state
@pytest.mark.parametrize("env_class", [CartPoleJaxVectorEnv, PendulumJaxVectorEnv])
def test_vectorized(env_class):
env = env_class(num_envs=10)
env.action_space.seed(0)
obs, info = env.reset(seed=0)
assert obs.shape == (10,) + env.single_observation_space.shape
assert isinstance(obs, np.ndarray)
assert isinstance(info, dict)
for t in range(100):
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
assert obs.shape == (10,) + env.single_observation_space.shape
assert isinstance(obs, np.ndarray)
assert reward.shape == (10,)
assert isinstance(reward, np.ndarray)
assert terminated.shape == (10,)
assert isinstance(terminated, np.ndarray)
assert truncated.shape == (10,)
assert isinstance(truncated, np.ndarray)
assert isinstance(info, dict)
# These were removed in the new autoreset order
assert "final_observation" not in info
assert "final_info" not in info
assert "_final_observation" not in info
assert "_final_info" not in info

View File

@@ -33,7 +33,7 @@ def test_create_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_async_vector_env(shared_memory):
"""Test the reset of an sync vector environment with or without shared memory."""
"""Test the reset of async vector environment with or without shared memory."""
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)

View File

@@ -6,6 +6,7 @@ import numpy as np
import pytest
from gymnasium.spaces import Discrete
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
from tests.testing_env import GenericTestEnv
from tests.vector.testing_utils import make_env
@@ -29,6 +30,7 @@ def test_vector_env_equal(shared_memory):
async_observations, async_infos = async_env.reset(seed=0)
sync_observations, sync_infos = sync_env.reset(seed=0)
assert np.all(async_observations == sync_observations)
assert data_equivalence(async_infos, sync_infos)
for _ in range(num_steps):
actions = async_env.action_space.sample()
@@ -49,16 +51,11 @@ def test_vector_env_equal(shared_memory):
sync_infos,
) = sync_env.step(actions)
if any(sync_terminations) or any(sync_truncations):
assert "final_observation" in async_infos
assert "_final_observation" in async_infos
assert "final_observation" in sync_infos
assert "_final_observation" in sync_infos
assert np.all(async_observations == sync_observations)
assert np.all(async_rewards == sync_rewards)
assert np.all(async_terminations == sync_terminations)
assert np.all(async_truncations == sync_truncations)
assert data_equivalence(async_infos, sync_infos)
async_env.close()
sync_env.close()
@@ -115,14 +112,13 @@ def test_final_obs_info(vectoriser):
)
obs, _, termination, _, info = env.step([3])
assert obs == np.array([0]) and info == {"action": 3, "_action": np.array([True])}
obs, _, terminated, _, info = env.step([4])
assert (
obs == np.array([0])
and termination == np.array([True])
and info["reset"] == np.array([True])
)
assert "final_observation" in info and "final_info" in info
assert info["final_observation"] == np.array([0]) and info["final_info"] == {
"action": 3
}
env.close()

View File

@@ -1,66 +1,164 @@
"""Test the vector environment information."""
from __future__ import annotations
from typing import Any, SupportsFloat
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.vector.sync_vector_env import SyncVectorEnv
from tests.vector.testing_utils import make_env
from gymnasium.core import ActType, ObsType
from gymnasium.spaces import Box, Discrete
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv
ENV_ID = "CartPole-v1"
NUM_ENVS = 3
ENV_STEPS = 50
SEED = 42
def test_vector_add_info():
env = VectorEnv()
@pytest.mark.parametrize("vectorization_mode", ["async", "sync"])
def test_vector_env_info(vectorization_mode: str):
"""Test vector environment info for different vectorization modes."""
env = gym.make_vec(
ENV_ID,
num_envs=NUM_ENVS,
vectorization_mode=vectorization_mode,
# Test num-envs==1 then expand_dims(sub-env-info) == vector-infos
env.num_envs = 1
sub_env_info = {"a": 0, "b": 0.0, "c": None, "d": np.zeros((2,)), "e": Discrete(1)}
vector_infos = env._add_info({}, sub_env_info, 0)
expected_vector_infos = {
"a": np.array([0]),
"b": np.array([0.0]),
"c": np.array([None], dtype=object),
"d": np.zeros(
(
1,
2,
)
env.reset(seed=SEED)
for _ in range(ENV_STEPS):
env.action_space.seed(SEED)
action = env.action_space.sample()
_, _, terminateds, truncateds, infos = env.step(action)
if any(terminateds) or any(truncateds):
assert len(infos["final_observation"]) == NUM_ENVS
assert len(infos["_final_observation"]) == NUM_ENVS
),
"e": np.array([Discrete(1)], dtype=object),
"_a": np.array([True]),
"_b": np.array([True]),
"_c": np.array([True]),
"_d": np.array([True]),
"_e": np.array([True]),
}
assert data_equivalence(vector_infos, expected_vector_infos)
assert isinstance(infos["final_observation"], np.ndarray)
assert isinstance(infos["_final_observation"], np.ndarray)
# Thought: num-envs>1 then vector-infos should have the same structure as sub-env-info
env.num_envs = 3
sub_env_infos = [
{"a": 0, "b": 0.0, "c": None, "d": np.zeros((2,)), "e": Discrete(1)},
{"a": 1, "b": 1.0, "c": None, "d": np.zeros((2,)), "e": Discrete(2)},
{"a": 2, "b": 2.0, "c": None, "d": np.zeros((2,)), "e": Discrete(3)},
]
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
if terminated or truncated:
assert infos["_final_observation"][i]
else:
assert not infos["_final_observation"][i]
assert infos["final_observation"][i] is None
vector_infos = {}
for i, info in enumerate(sub_env_infos):
vector_infos = env._add_info(vector_infos, info, i)
env.close()
expected_vector_infos = {
"a": np.array([0, 1, 2]),
"b": np.array([0.0, 1.0, 2.0]),
"c": np.array([None, None, None], dtype=object),
"d": np.zeros((3, 2)),
"e": np.array([Discrete(1), Discrete(2), Discrete(3)], dtype=object),
"_a": np.array([True, True, True]),
"_b": np.array([True, True, True]),
"_c": np.array([True, True, True]),
"_d": np.array([True, True, True]),
"_e": np.array([True, True, True]),
}
assert data_equivalence(vector_infos, expected_vector_infos)
# Test different structures of sub-infos
env.num_envs = 3
sub_env_infos = [
{"a": 1, "b": 1.0},
{"c": None, "d": np.zeros((2,))},
{"e": Discrete(3)},
]
vector_infos = {}
for i, info in enumerate(sub_env_infos):
vector_infos = env._add_info(vector_infos, info, i)
expected_vector_infos = {
"a": np.array([1, 0, 0]),
"b": np.array([1.0, 0.0, 0.0]),
"c": np.array([None, None, None], dtype=object),
"d": np.zeros((3, 2)),
"e": np.array([None, None, Discrete(3)], dtype=object),
"_a": np.array([True, False, False]),
"_b": np.array([True, False, False]),
"_c": np.array([False, True, False]),
"_d": np.array([False, True, False]),
"_e": np.array([False, False, True]),
}
assert data_equivalence(vector_infos, expected_vector_infos)
# Test recursive structure
env.num_envs = 3
sub_env_infos = [
{"episode": {"a": 1, "b": 1.0}},
{"episode": {"a": 2, "b": 2.0}, "a": 1},
{"a": 2},
]
vector_infos = {}
for i, info in enumerate(sub_env_infos):
vector_infos = env._add_info(vector_infos, info, i)
expected_vector_infos = {
"episode": {
"a": np.array([1, 2, 0]),
"b": np.array([1.0, 2.0, 0.0]),
"_a": np.array([True, True, False]),
"_b": np.array([True, True, False]),
},
"_episode": np.array([True, True, False]),
"a": np.array([0, 1, 2]),
"_a": np.array([False, True, True]),
}
assert data_equivalence(vector_infos, expected_vector_infos)
@pytest.mark.parametrize("concurrent_ends", [1, 2, 3])
def test_vector_env_info_concurrent_termination(concurrent_ends):
"""Test the vector environment information works with concurrent termination."""
# envs that need to terminate together will have the same action
actions = [0] * concurrent_ends + [1] * (NUM_ENVS - concurrent_ends)
envs = [make_env(ENV_ID, SEED) for _ in range(NUM_ENVS)]
envs = SyncVectorEnv(envs)
class ReturnInfoEnv(gym.Env):
def __init__(self, infos):
self.observation_space = Box(0, 1)
self.action_space = Box(0, 1)
for _ in range(ENV_STEPS):
_, _, terminateds, truncateds, infos = envs.step(actions)
if any(terminateds) or any(truncateds):
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
if i < concurrent_ends:
assert terminated or truncated
assert infos["_final_observation"][i]
else:
assert not infos["_final_observation"][i]
assert infos["final_observation"][i] is None
return
self.infos = infos
envs.close()
def reset(
self,
*,
seed: int | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
return self.observation_space.sample(), self.infos[0]
def step(
self, action: ActType
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
return self.observation_space.sample(), 0, True, False, self.infos[1]
@pytest.mark.parametrize("vectorizer", [AsyncVectorEnv, SyncVectorEnv])
def test_vectorizers(vectorizer):
vec_env = vectorizer(
[
lambda: ReturnInfoEnv([{"a": 1}, {"c": np.array([1, 2])}]),
lambda: ReturnInfoEnv([{"a": 2, "b": 3}, {"c": np.array([3, 4])}]),
]
)
reset_expected_infos = {
"a": np.array([1, 2]),
"b": np.array([0, 3]),
"_a": np.array([True, True]),
"_b": np.array([False, True]),
}
step_expected_infos = {
"c": np.array([[1, 2], [3, 4]]),
"_c": np.array([True, True]),
}
_, reset_info = vec_env.reset()
assert data_equivalence(reset_info, reset_expected_infos)
_, _, _, _, step_info = vec_env.step(vec_env.action_space.sample())
assert data_equivalence(step_info, step_expected_infos)

View File

@@ -45,19 +45,15 @@ def test_autoreset_wrapper_autoreset():
assert info == {"count": 2}
obs, reward, terminated, truncated, info = env.step(action)
assert obs == np.array([0])
assert obs == np.array([3])
assert (terminated or truncated) is True
assert reward == 1
assert info == {
"count": 0,
"final_observation": np.array([3]),
"final_info": {"count": 3},
}
assert info == {"count": 3}
obs, reward, terminated, truncated, info = env.step(action)
assert obs == np.array([1])
assert obs == np.array([0])
assert reward == 0
assert (terminated or truncated) is False
assert info == {"count": 1}
assert info == {"count": 0}
env.close()

View File

@@ -1,21 +1,22 @@
"""Test suite for DictInfoTolist wrapper."""
from __future__ import annotations
from typing import Any
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.wrappers.vector import DictInfoToList, RecordEpisodeStatistics
from gymnasium.core import ObsType
from gymnasium.spaces import Discrete
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.vector import VectorEnv
from gymnasium.wrappers.vector import DictInfoToList
ENV_ID = "CartPole-v1"
NUM_ENVS = 3
ENV_STEPS = 50
SEED = 42
def test_usage_in_vector_env():
env = gym.make(ENV_ID, disable_env_checker=True)
vector_env = gym.make_vec(ENV_ID, num_envs=NUM_ENVS, vectorization_mode="sync")
def test_usage_in_vector_env(env_id: str = "CartPole-v1", num_envs: int = 3):
env = gym.make(env_id, disable_env_checker=True)
vector_env = gym.make_vec(env_id, num_envs=num_envs)
DictInfoToList(vector_env)
@@ -23,40 +24,140 @@ def test_usage_in_vector_env():
DictInfoToList(env)
def test_info_to_list():
env_to_wrap = gym.make_vec(ENV_ID, num_envs=NUM_ENVS, vectorization_mode="sync")
wrapped_env = DictInfoToList(env_to_wrap)
wrapped_env.action_space.seed(SEED)
_, info = wrapped_env.reset(seed=SEED)
assert isinstance(info, list)
assert len(info) == NUM_ENVS
class ResetOptionAsInfo(VectorEnv):
"""Minimal implementation to test the conversion of vector dict info to list info."""
for _ in range(ENV_STEPS):
action = wrapped_env.action_space.sample()
_, _, terminateds, truncateds, list_info = wrapped_env.step(action)
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
if terminated or truncated:
assert "final_observation" in list_info[i]
else:
assert "final_observation" not in list_info[i]
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None, # options are passed are the info output
) -> tuple[ObsType, dict[str, Any]]:
return None, options
def test_info_to_list_statistics():
env_to_wrap = gym.make_vec(ENV_ID, num_envs=NUM_ENVS, vectorization_mode="sync")
wrapped_env = DictInfoToList(RecordEpisodeStatistics(env_to_wrap))
_, info = wrapped_env.reset(seed=SEED)
wrapped_env.action_space.seed(SEED)
assert isinstance(info, list)
assert len(info) == NUM_ENVS
def test_update_info():
env = DictInfoToList(ResetOptionAsInfo())
for _ in range(ENV_STEPS):
action = wrapped_env.action_space.sample()
_, _, terminateds, truncateds, list_info = wrapped_env.step(action)
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
if terminated or truncated:
assert "episode" in list_info[i]
for stats in ["r", "l", "t"]:
assert stats in list_info[i]["episode"]
assert np.isscalar(list_info[i]["episode"][stats])
else:
assert "episode" not in list_info[i]
# Test num-envs==1 then expand_dims(sub-env-info) == vector-infos
env.unwrapped.num_envs = 1
vector_infos = {
"a": np.array([0]),
"b": np.array([0.0]),
"c": np.array([None], dtype=object),
"d": np.zeros(
(
1,
2,
)
),
"e": np.array([Discrete(1)], dtype=object),
"_a": np.array([True]),
"_b": np.array([True]),
"_c": np.array([True]),
"_d": np.array([True]),
"_e": np.array([True]),
}
_, list_info = env.reset(options=vector_infos)
expected_list_info = [
{
"a": np.int64(0),
"b": np.float64(0.0),
"c": None,
"d": np.zeros((2,)),
"e": Discrete(1),
}
]
assert data_equivalence(list_info, expected_list_info)
# Thought: num-envs>1 then vector-infos should have the same structure as sub-env-info
env.unwrapped.num_envs = 3
vector_infos = {
"a": np.array([0, 1, 2]),
"b": np.array([0.0, 1.0, 2.0]),
"c": np.array([None, None, None], dtype=object),
"d": np.zeros((3, 2)),
"e": np.array([Discrete(1), Discrete(2), Discrete(3)], dtype=object),
"_a": np.array([True, True, True]),
"_b": np.array([True, True, True]),
"_c": np.array([True, True, True]),
"_d": np.array([True, True, True]),
"_e": np.array([True, True, True]),
}
_, list_info = env.reset(options=vector_infos)
expected_list_info = [
{
"a": np.int64(0),
"b": np.float64(0.0),
"c": None,
"d": np.zeros((2,)),
"e": Discrete(1),
},
{
"a": np.int64(1),
"b": np.float64(1.0),
"c": None,
"d": np.zeros((2,)),
"e": Discrete(2),
},
{
"a": np.int64(2),
"b": np.float64(2.0),
"c": None,
"d": np.zeros((2,)),
"e": Discrete(3),
},
]
assert list_info[0].keys() == expected_list_info[0].keys()
for key in list_info[0].keys():
assert data_equivalence(list_info[0][key], expected_list_info[0][key])
assert data_equivalence(list_info, expected_list_info)
# Test different structures of sub-infos
env.unwrapped.num_envs = 3
vector_infos = {
"a": np.array([1, 0, 0]),
"_a": np.array([True, False, False]),
"b": np.array([1.0, 0.0, 0.0]),
"_b": np.array([True, False, False]),
"c": np.array([None, None, None], dtype=object),
"_c": np.array([False, True, False]),
"_d": np.array([False, True, False]),
"d": np.zeros((3, 2)),
"e": np.array([None, None, Discrete(3)], dtype=object),
"_e": np.array([False, False, True]),
}
_, list_info = env.reset(options=vector_infos)
expected_list_info = [
{"a": np.int64(1), "b": np.float64(1.0)},
{"c": None, "d": np.zeros((2,))},
{"e": Discrete(3)},
]
assert data_equivalence(list_info, expected_list_info)
# Test recursive structure
env.unwrapped.num_envs = 3
vector_infos = {
"episode": {
"a": np.array([1, 2, 0]),
"b": np.array([1.0, 2.0, 0.0]),
"_a": np.array([True, True, False]),
"_b": np.array([True, True, False]),
},
"_episode": np.array([True, True, False]),
"a": np.array([0, 1, 2]),
"_a": np.array([False, True, True]),
}
_, list_info = env.reset(options=vector_infos)
expected_list_info = [
{"episode": {"a": np.int64(1), "b": np.float64(1.0)}},
{"episode": {"a": np.int64(2), "b": np.float64(2.0)}, "a": np.int64(1)},
{"a": np.int64(2)},
]
assert data_equivalence(list_info, expected_list_info)

View File

@@ -0,0 +1,71 @@
import pytest
import gymnasium as gym
from gymnasium.utils.env_checker import data_equivalence
from gymnasium.vector import VectorEnv
@pytest.mark.parametrize("num_envs", (1, 3))
def test_record_episode_statistics(num_envs, env_id="CartPole-v1", num_steps=100):
wrapper_vector_env: VectorEnv = gym.wrappers.vector.RecordEpisodeStatistics(
gym.make_vec(id=env_id, num_envs=num_envs, vectorization_mode="sync"),
)
vector_wrapper_env = gym.make_vec(
id=env_id,
num_envs=num_envs,
vectorization_mode="sync",
wrappers=(gym.wrappers.RecordEpisodeStatistics,),
)
assert wrapper_vector_env.action_space == vector_wrapper_env.action_space
assert wrapper_vector_env.observation_space == vector_wrapper_env.observation_space
assert (
wrapper_vector_env.single_action_space == vector_wrapper_env.single_action_space
)
assert (
wrapper_vector_env.single_observation_space
== vector_wrapper_env.single_observation_space
)
assert wrapper_vector_env.num_envs == vector_wrapper_env.num_envs
wrapper_vector_obs, wrapper_vector_info = wrapper_vector_env.reset(seed=123)
vector_wrapper_obs, vector_wrapper_info = vector_wrapper_env.reset(seed=123)
assert data_equivalence(wrapper_vector_obs, vector_wrapper_obs)
assert data_equivalence(wrapper_vector_info, vector_wrapper_info)
for _ in range(num_steps):
action = wrapper_vector_env.action_space.sample()
(
wrapper_vector_obs,
wrapper_vector_reward,
wrapper_vector_terminated,
wrapper_vector_truncated,
wrapper_vector_info,
) = wrapper_vector_env.step(action)
(
vector_wrapper_obs,
vector_wrapper_reward,
vector_wrapper_terminated,
vector_wrapper_truncated,
vector_wrapper_info,
) = vector_wrapper_env.step(action)
data_equivalence(wrapper_vector_obs, vector_wrapper_obs)
data_equivalence(wrapper_vector_reward, vector_wrapper_reward)
data_equivalence(wrapper_vector_terminated, vector_wrapper_terminated)
data_equivalence(wrapper_vector_truncated, vector_wrapper_truncated)
if "episode" in wrapper_vector_info:
assert "episode" in vector_wrapper_info
wrapper_vector_time = wrapper_vector_info["episode"].pop("t")
vector_wrapper_time = vector_wrapper_info["episode"].pop("t")
assert wrapper_vector_time.shape == vector_wrapper_time.shape
assert wrapper_vector_time.dtype == vector_wrapper_time.dtype
data_equivalence(wrapper_vector_info, vector_wrapper_info)
wrapper_vector_env.close()
vector_wrapper_env.close()

View File

@@ -42,21 +42,21 @@ def custom_environments():
("CustomDictEnv-v0", "FilterObservation", {"filter_keys": ["a"]}),
("CartPole-v1", "FlattenObservation", {}),
("CarRacing-v2", "GrayscaleObservation", {}),
# ("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}),
("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}),
("CarRacing-v2", "ReshapeObservation", {"shape": (96, 48, 6)}),
("CartPole-v1", "RescaleObservation", {"min_obs": 0, "max_obs": 1}),
("CartPole-v1", "DtypeObservation", {"dtype": np.int32}),
# ("CartPole-v1", "RenderObservation", {}),
# ("CartPole-v1", "TimeAwareObservation", {}),
# ("CartPole-v1", "FrameStackObservation", {}),
# ("CartPole-v1", "DelayObservation", {}),
# ("CartPole-v1", "RenderObservation", {}), # not implemented
# ("CartPole-v1", "TimeAwareObservation", {}), # not implemented
# ("CartPole-v1", "FrameStackObservation", {}), # not implemented
# ("CartPole-v1", "DelayObservation", {}), # not implemented
("MountainCarContinuous-v0", "ClipAction", {}),
(
"MountainCarContinuous-v0",
"RescaleAction",
{"min_action": 1, "max_action": 2},
),
("CartPole-v1", "ClipReward", {"min_reward": 0.25, "max_reward": 0.75}),
("CartPole-v1", "ClipReward", {"min_reward": -0.25, "max_reward": 0.75}),
),
)
def test_vector_wrapper_equivalence(