mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 17:57:30 +00:00
Change autoreset order (#808)
Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
committed by
GitHub
parent
967bbf5823
commit
e9c66e4225
@@ -151,6 +151,8 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
|
|||||||
|
|
||||||
self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)
|
self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)
|
||||||
|
|
||||||
|
self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_)
|
||||||
|
|
||||||
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
|
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
|
||||||
|
|
||||||
if self.render_mode == "rgb_array":
|
if self.render_mode == "rgb_array":
|
||||||
@@ -214,10 +216,9 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
|
|||||||
info = self.func_env.transition_info(self.state, action, next_state)
|
info = self.func_env.transition_info(self.state, action, next_state)
|
||||||
|
|
||||||
done = jnp.logical_or(terminated, truncated)
|
done = jnp.logical_or(terminated, truncated)
|
||||||
if jnp.any(done):
|
|
||||||
final_obs = self.func_env.observation(next_state)
|
|
||||||
|
|
||||||
to_reset = jnp.where(done)[0]
|
if jnp.any(self.autoreset_envs):
|
||||||
|
to_reset = jnp.where(self.autoreset_envs)[0]
|
||||||
reset_count = to_reset.shape[0]
|
reset_count = to_reset.shape[0]
|
||||||
|
|
||||||
rng, self.rng = jrng.split(self.rng)
|
rng, self.rng = jrng.split(self.rng)
|
||||||
@@ -228,34 +229,16 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
|
|||||||
next_state = self.state.at[to_reset].set(new_initials)
|
next_state = self.state.at[to_reset].set(new_initials)
|
||||||
self.steps = self.steps.at[to_reset].set(0)
|
self.steps = self.steps.at[to_reset].set(0)
|
||||||
|
|
||||||
# Get the final observations and infos
|
self.autoreset_envs = done
|
||||||
info["final_observation"] = np.array([None for _ in range(self.num_envs)])
|
|
||||||
info["final_info"] = np.array([None for _ in range(self.num_envs)])
|
|
||||||
|
|
||||||
info["_final_observation"] = np.array([False for _ in range(self.num_envs)])
|
|
||||||
info["_final_info"] = np.array([False for _ in range(self.num_envs)])
|
|
||||||
|
|
||||||
# TODO: this can maybe be optimized, but right now I don't know how
|
|
||||||
for i in to_reset:
|
|
||||||
info["final_observation"][i] = final_obs[i]
|
|
||||||
info["final_info"][i] = {
|
|
||||||
k: v[i]
|
|
||||||
for k, v in info.items()
|
|
||||||
if k
|
|
||||||
not in {
|
|
||||||
"final_observation",
|
|
||||||
"final_info",
|
|
||||||
"_final_observation",
|
|
||||||
"_final_info",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
info["_final_observation"][i] = True
|
|
||||||
info["_final_info"][i] = True
|
|
||||||
|
|
||||||
observation = self.func_env.observation(next_state)
|
observation = self.func_env.observation(next_state)
|
||||||
observation = jax_to_numpy(observation)
|
observation = jax_to_numpy(observation)
|
||||||
|
|
||||||
|
reward = jax_to_numpy(reward)
|
||||||
|
|
||||||
|
terminated = jax_to_numpy(terminated)
|
||||||
|
truncated = jax_to_numpy(truncated)
|
||||||
|
|
||||||
self.state = next_state
|
self.state = next_state
|
||||||
|
|
||||||
return observation, reward, terminated, truncated, info
|
return observation, reward, terminated, truncated, info
|
||||||
|
@@ -639,6 +639,7 @@ def _async_worker(
|
|||||||
env = env_fn()
|
env = env_fn()
|
||||||
observation_space = env.observation_space
|
observation_space = env.observation_space
|
||||||
action_space = env.action_space
|
action_space = env.action_space
|
||||||
|
autoreset = False
|
||||||
|
|
||||||
parent_pipe.close()
|
parent_pipe.close()
|
||||||
|
|
||||||
@@ -653,20 +654,21 @@ def _async_worker(
|
|||||||
observation_space, index, observation, shared_memory
|
observation_space, index, observation, shared_memory
|
||||||
)
|
)
|
||||||
observation = None
|
observation = None
|
||||||
|
autoreset = False
|
||||||
pipe.send(((observation, info), True))
|
pipe.send(((observation, info), True))
|
||||||
elif command == "step":
|
elif command == "step":
|
||||||
(
|
if autoreset:
|
||||||
observation,
|
|
||||||
reward,
|
|
||||||
terminated,
|
|
||||||
truncated,
|
|
||||||
info,
|
|
||||||
) = env.step(data)
|
|
||||||
if terminated or truncated:
|
|
||||||
old_observation, old_info = observation, info
|
|
||||||
observation, info = env.reset()
|
observation, info = env.reset()
|
||||||
info["final_observation"] = old_observation
|
reward, terminated, truncated = 0, False, False
|
||||||
info["final_info"] = old_info
|
else:
|
||||||
|
(
|
||||||
|
observation,
|
||||||
|
reward,
|
||||||
|
terminated,
|
||||||
|
truncated,
|
||||||
|
info,
|
||||||
|
) = env.step(data)
|
||||||
|
autoreset = terminated or truncated
|
||||||
|
|
||||||
if shared_memory:
|
if shared_memory:
|
||||||
write_to_shared_memory(
|
write_to_shared_memory(
|
||||||
|
@@ -98,6 +98,8 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
|
self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||||
self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)
|
self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||||
|
|
||||||
|
self._autoreset_envs = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||||
|
|
||||||
def reset(
|
def reset(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -150,22 +152,21 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
actions = iterate(self.action_space, actions)
|
actions = iterate(self.action_space, actions)
|
||||||
|
|
||||||
observations, infos = [], {}
|
observations, infos = [], {}
|
||||||
for i, (env, action) in enumerate(zip(self.envs, actions)):
|
for i, action in enumerate(actions):
|
||||||
(
|
if self._autoreset_envs[i]:
|
||||||
env_obs,
|
env_obs, env_info = self.envs[i].reset()
|
||||||
self._rewards[i],
|
|
||||||
self._terminations[i],
|
|
||||||
self._truncations[i],
|
|
||||||
env_info,
|
|
||||||
) = env.step(action)
|
|
||||||
|
|
||||||
# If sub-environments terminates or truncates then save the obs and info to the batched info
|
self._rewards[i] = 0.0
|
||||||
if self._terminations[i] or self._truncations[i]:
|
self._terminations[i] = False
|
||||||
old_observation, old_info = env_obs, env_info
|
self._truncations[i] = False
|
||||||
env_obs, env_info = env.reset()
|
else:
|
||||||
|
(
|
||||||
env_info["final_observation"] = old_observation
|
env_obs,
|
||||||
env_info["final_info"] = old_info
|
self._rewards[i],
|
||||||
|
self._terminations[i],
|
||||||
|
self._truncations[i],
|
||||||
|
env_info,
|
||||||
|
) = self.envs[i].step(action)
|
||||||
|
|
||||||
observations.append(env_obs)
|
observations.append(env_obs)
|
||||||
infos = self._add_info(infos, env_info, i)
|
infos = self._add_info(infos, env_info, i)
|
||||||
@@ -174,6 +175,7 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
self._observations = concatenate(
|
self._observations = concatenate(
|
||||||
self.single_observation_space, observations, self._observations
|
self.single_observation_space, observations, self._observations
|
||||||
)
|
)
|
||||||
|
self._autoreset_envs = np.logical_or(self._terminations, self._truncations)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
deepcopy(self._observations) if self.copy else self._observations,
|
deepcopy(self._observations) if self.copy else self._observations,
|
||||||
|
@@ -231,7 +231,7 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def _add_info(
|
def _add_info(
|
||||||
self, infos: dict[str, Any], info: dict[str, Any], env_num: int
|
self, vector_infos: dict[str, Any], env_info: dict[str, Any], env_num: int
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Add env info to the info dictionary of the vectorized environment.
|
"""Add env info to the info dictionary of the vectorized environment.
|
||||||
|
|
||||||
@@ -241,48 +241,51 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
|
|||||||
whether or not the i-indexed environment has this `info`.
|
whether or not the i-indexed environment has this `info`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
infos (dict): the infos of the vectorized environment
|
vector_infos (dict): the infos of the vectorized environment
|
||||||
info (dict): the info coming from the single environment
|
env_info (dict): the info coming from the single environment
|
||||||
env_num (int): the index of the single environment
|
env_num (int): the index of the single environment
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
infos (dict): the (updated) infos of the vectorized environment
|
infos (dict): the (updated) infos of the vectorized environment
|
||||||
|
|
||||||
"""
|
"""
|
||||||
for k in info.keys():
|
for key, value in env_info.items():
|
||||||
if k not in infos:
|
# If value is a dictionary, then we apply the `_add_info` recursively.
|
||||||
info_array, array_mask = self._init_info_arrays(type(info[k]))
|
if isinstance(value, dict):
|
||||||
|
array = self._add_info(vector_infos.get(key, {}), value, env_num)
|
||||||
|
# Otherwise, we are a base case to group the data
|
||||||
else:
|
else:
|
||||||
info_array, array_mask = infos[k], infos[f"_{k}"]
|
# If the key doesn't exist in the vector infos, then we can create an array of that batch type
|
||||||
|
if key not in vector_infos:
|
||||||
|
if type(value) in [int, float, bool] or issubclass(
|
||||||
|
type(value), np.number
|
||||||
|
):
|
||||||
|
array = np.zeros(self.num_envs, dtype=type(value))
|
||||||
|
elif isinstance(value, np.ndarray):
|
||||||
|
# We assume that all instances of the np.array info are of the same shape
|
||||||
|
array = np.zeros(
|
||||||
|
(self.num_envs, *value.shape), dtype=value.dtype
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# For unknown objects, we use a Numpy object array
|
||||||
|
array = np.full(self.num_envs, fill_value=None, dtype=object)
|
||||||
|
# Otherwise, just use the array that already exists
|
||||||
|
else:
|
||||||
|
array = vector_infos[key]
|
||||||
|
|
||||||
info_array[env_num], array_mask[env_num] = info[k], True
|
# Assign the data in the `env_num` position
|
||||||
infos[k], infos[f"_{k}"] = info_array, array_mask
|
# We only want to run this for the base-case data (not recursive data forcing the ugly function structure)
|
||||||
return infos
|
array[env_num] = value
|
||||||
|
|
||||||
def _init_info_arrays(self, dtype: type) -> tuple[np.ndarray, np.ndarray]:
|
# Get the array mask and if it doesn't already exist then create a zero bool array
|
||||||
"""Initialize the info array.
|
array_mask = vector_infos.get(
|
||||||
|
f"_{key}", np.zeros(self.num_envs, dtype=np.bool_)
|
||||||
|
)
|
||||||
|
array_mask[env_num] = True
|
||||||
|
|
||||||
Initialize the info array. If the dtype is numeric
|
# Update the vector info with the updated data and mask information
|
||||||
the info array will have the same dtype, otherwise
|
vector_infos[key], vector_infos[f"_{key}"] = array, array_mask
|
||||||
will be an array of `None`. Also, a boolean array
|
|
||||||
of the same length is returned. It will be used for
|
|
||||||
assessing which environment has info data.
|
|
||||||
|
|
||||||
Args:
|
return vector_infos
|
||||||
dtype (type): data type of the info coming from the env.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
array (np.ndarray): the initialized info array.
|
|
||||||
array_mask (np.ndarray): the initialized boolean array.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if dtype in [int, float, bool] or issubclass(dtype, np.number):
|
|
||||||
array = np.zeros(self.num_envs, dtype=dtype)
|
|
||||||
else:
|
|
||||||
array = np.zeros(self.num_envs, dtype=object)
|
|
||||||
array[:] = None
|
|
||||||
array_mask = np.zeros(self.num_envs, dtype=bool)
|
|
||||||
return array, array_mask
|
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
"""Closes the vector environment."""
|
"""Closes the vector environment."""
|
||||||
@@ -441,23 +444,23 @@ class VectorObservationWrapper(VectorWrapper):
|
|||||||
options: dict[str, Any] | None = None,
|
options: dict[str, Any] | None = None,
|
||||||
) -> tuple[ObsType, dict[str, Any]]:
|
) -> tuple[ObsType, dict[str, Any]]:
|
||||||
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
|
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
|
||||||
obs, info = self.env.reset(seed=seed, options=options)
|
observations, infos = self.env.reset(seed=seed, options=options)
|
||||||
return self.vector_observation(obs), info
|
return self.observation(observations), infos
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self, actions: ActType
|
self, actions: ActType
|
||||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
|
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
|
||||||
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
|
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
|
||||||
observation, reward, termination, truncation, info = self.env.step(actions)
|
observations, rewards, terminations, truncations, infos = self.env.step(actions)
|
||||||
return (
|
return (
|
||||||
self.vector_observation(observation),
|
self.observation(observations),
|
||||||
reward,
|
rewards,
|
||||||
termination,
|
terminations,
|
||||||
truncation,
|
truncations,
|
||||||
self.update_final_obs(info),
|
infos,
|
||||||
)
|
)
|
||||||
|
|
||||||
def vector_observation(self, observation: ObsType) -> ObsType:
|
def observation(self, observation: ObsType) -> ObsType:
|
||||||
"""Defines the vector observation transformation.
|
"""Defines the vector observation transformation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -468,25 +471,6 @@ class VectorObservationWrapper(VectorWrapper):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def single_observation(self, observation: ObsType) -> ObsType:
|
|
||||||
"""Defines the single observation transformation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
observation: A single observation from the environment
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The transformed observation
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Updates the `final_obs` in the info using `single_observation`."""
|
|
||||||
if "final_observation" in info:
|
|
||||||
for i, obs in enumerate(info["final_observation"]):
|
|
||||||
if obs is not None:
|
|
||||||
info["final_observation"][i] = self.single_observation(obs)
|
|
||||||
return info
|
|
||||||
|
|
||||||
|
|
||||||
class VectorActionWrapper(VectorWrapper):
|
class VectorActionWrapper(VectorWrapper):
|
||||||
"""Wraps the vectorized environment to allow a modular transformation of the actions.
|
"""Wraps the vectorized environment to allow a modular transformation of the actions.
|
||||||
@@ -522,14 +506,14 @@ class VectorRewardWrapper(VectorWrapper):
|
|||||||
self, actions: ActType
|
self, actions: ActType
|
||||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
|
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
|
||||||
"""Steps through the environment returning a reward modified by :meth:`reward`."""
|
"""Steps through the environment returning a reward modified by :meth:`reward`."""
|
||||||
observation, reward, termination, truncation, info = self.env.step(actions)
|
observations, rewards, terminations, truncations, infos = self.env.step(actions)
|
||||||
return observation, self.rewards(reward), termination, truncation, info
|
return observations, self.rewards(rewards), terminations, truncations, infos
|
||||||
|
|
||||||
def rewards(self, reward: ArrayType) -> ArrayType:
|
def rewards(self, rewards: ArrayType) -> ArrayType:
|
||||||
"""Transform the reward before returning it.
|
"""Transform the reward before returning it.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
reward (array): the reward to transform
|
rewards (array): the reward to transform
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: the transformed reward
|
array: the transformed reward
|
||||||
|
@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any, SupportsFloat
|
|||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium import logger
|
from gymnasium import logger
|
||||||
from gymnasium.core import ActType, ObsType, RenderFrame
|
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperObsType
|
||||||
from gymnasium.error import ResetNeeded
|
from gymnasium.error import ResetNeeded
|
||||||
from gymnasium.utils.passive_env_checker import (
|
from gymnasium.utils.passive_env_checker import (
|
||||||
check_action_space,
|
check_action_space,
|
||||||
@@ -196,6 +196,15 @@ class Autoreset(
|
|||||||
gym.utils.RecordConstructorArgs.__init__(self)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
gym.Wrapper.__init__(self, env)
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
|
self.autoreset = False
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||||
|
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||||
|
"""Resets the environment and sets autoreset to False preventing."""
|
||||||
|
self.autoreset = False
|
||||||
|
return super().reset(seed=seed, options=options)
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self, action: ActType
|
self, action: ActType
|
||||||
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||||
@@ -207,24 +216,13 @@ class Autoreset(
|
|||||||
Returns:
|
Returns:
|
||||||
The autoreset environment :meth:`step`
|
The autoreset environment :meth:`step`
|
||||||
"""
|
"""
|
||||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
if self.autoreset:
|
||||||
|
obs, info = self.env.reset()
|
||||||
if terminated or truncated:
|
reward, terminated, truncated = 0.0, False, False
|
||||||
new_obs, new_info = self.env.reset()
|
else:
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||||
assert (
|
|
||||||
"final_observation" not in new_info
|
|
||||||
), f'new info dict already contains "final_observation", info keys: {new_info.keys()}'
|
|
||||||
assert (
|
|
||||||
"final_info" not in new_info
|
|
||||||
), f'new info dict already contains "final_observation", info keys: {new_info.keys()}'
|
|
||||||
|
|
||||||
new_info["final_observation"] = obs
|
|
||||||
new_info["final_info"] = info
|
|
||||||
|
|
||||||
obs = new_obs
|
|
||||||
info = new_info
|
|
||||||
|
|
||||||
|
self.autoreset = terminated or truncated
|
||||||
return obs, reward, terminated, truncated, info
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
|
||||||
@@ -470,14 +468,14 @@ class RecordEpisodeStatistics(
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: gym.Env[ObsType, ActType],
|
env: gym.Env[ObsType, ActType],
|
||||||
buffer_length: int | None = 100,
|
buffer_length: int = 100,
|
||||||
stats_key: str = "episode",
|
stats_key: str = "episode",
|
||||||
):
|
):
|
||||||
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env (Env): The environment to apply the wrapper
|
env (Env): The environment to apply the wrapper
|
||||||
buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
buffer_length: The size of the buffers :attr:`return_queue`, :attr:`length_queue` and :attr:`time_queue`
|
||||||
stats_key: The info key for the episode statistics
|
stats_key: The info key for the episode statistics
|
||||||
"""
|
"""
|
||||||
gym.utils.RecordConstructorArgs.__init__(self)
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
@@ -520,6 +518,7 @@ class RecordEpisodeStatistics(
|
|||||||
self.length_queue.append(self.episode_lengths)
|
self.length_queue.append(self.episode_lengths)
|
||||||
|
|
||||||
self.episode_count += 1
|
self.episode_count += 1
|
||||||
|
self.episode_start_time = time.perf_counter()
|
||||||
|
|
||||||
return obs, reward, terminated, truncated, info
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
@@ -444,7 +444,7 @@ class NormalizeObservation(
|
|||||||
|
|
||||||
Change logs:
|
Change logs:
|
||||||
* v0.21.0 - Initially add
|
* v0.21.0 - Initially add
|
||||||
* v1.0.0 - Add `update_running_mean` attribute to allow disabling of updating the running mean / standard
|
* v1.0.0 - Add `update_running_mean` attribute to allow disabling of updating the running mean / standard, particularly useful for evaluation time.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):
|
def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8):
|
||||||
|
@@ -62,14 +62,21 @@ class RecordEpisodeStatistics(VectorWrapper):
|
|||||||
None, None], dtype=object)}
|
None, None], dtype=object)}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env: VectorEnv, deque_size: int = 100):
|
def __init__(
|
||||||
|
self,
|
||||||
|
env: VectorEnv,
|
||||||
|
deque_size: int = 100,
|
||||||
|
stats_key: str = "episode",
|
||||||
|
):
|
||||||
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
"""This wrapper will keep track of cumulative rewards and episode lengths.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env (Env): The environment to apply the wrapper
|
env (Env): The environment to apply the wrapper
|
||||||
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue`
|
||||||
|
stats_key: The info key to save the data
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
self._stats_key = stats_key
|
||||||
|
|
||||||
self.episode_count = 0
|
self.episode_count = 0
|
||||||
|
|
||||||
@@ -77,6 +84,7 @@ class RecordEpisodeStatistics(VectorWrapper):
|
|||||||
self.episode_returns: np.ndarray = np.zeros(())
|
self.episode_returns: np.ndarray = np.zeros(())
|
||||||
self.episode_lengths: np.ndarray = np.zeros(())
|
self.episode_lengths: np.ndarray = np.zeros(())
|
||||||
|
|
||||||
|
self.time_queue = deque(maxlen=deque_size)
|
||||||
self.return_queue = deque(maxlen=deque_size)
|
self.return_queue = deque(maxlen=deque_size)
|
||||||
self.length_queue = deque(maxlen=deque_size)
|
self.length_queue = deque(maxlen=deque_size)
|
||||||
|
|
||||||
@@ -88,11 +96,9 @@ class RecordEpisodeStatistics(VectorWrapper):
|
|||||||
"""Resets the environment using kwargs and resets the episode returns and lengths."""
|
"""Resets the environment using kwargs and resets the episode returns and lengths."""
|
||||||
obs, info = super().reset(seed=seed, options=options)
|
obs, info = super().reset(seed=seed, options=options)
|
||||||
|
|
||||||
self.episode_start_times = np.full(
|
self.episode_start_times = np.full(self.num_envs, time.perf_counter())
|
||||||
self.num_envs, time.perf_counter(), dtype=np.float32
|
self.episode_returns = np.zeros(self.num_envs)
|
||||||
)
|
self.episode_lengths = np.zeros(self.num_envs)
|
||||||
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
|
|
||||||
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
|
|
||||||
|
|
||||||
return obs, info
|
return obs, info
|
||||||
|
|
||||||
@@ -110,7 +116,7 @@ class RecordEpisodeStatistics(VectorWrapper):
|
|||||||
|
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
infos, dict
|
infos, dict
|
||||||
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
|
), f"`vector.RecordEpisodeStatistics` requires `info` type to be `dict`, its actual type is {type(infos)}. This may be due to usage of other wrappers in the wrong order."
|
||||||
|
|
||||||
self.episode_returns += rewards
|
self.episode_returns += rewards
|
||||||
self.episode_lengths += 1
|
self.episode_lengths += 1
|
||||||
@@ -119,25 +125,25 @@ class RecordEpisodeStatistics(VectorWrapper):
|
|||||||
num_dones = np.sum(dones)
|
num_dones = np.sum(dones)
|
||||||
|
|
||||||
if num_dones:
|
if num_dones:
|
||||||
if "episode" in infos or "_episode" in infos:
|
if self._stats_key in infos or f"_{self._stats_key}" in infos:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Attempted to add episode stats when they already exist"
|
f"Attempted to add episode stats when they already exist, info keys: {list(infos.keys())}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
infos["episode"] = {
|
episode_time_length = np.round(
|
||||||
|
time.perf_counter() - self.episode_start_times, 6
|
||||||
|
)
|
||||||
|
infos[self._stats_key] = {
|
||||||
"r": np.where(dones, self.episode_returns, 0.0),
|
"r": np.where(dones, self.episode_returns, 0.0),
|
||||||
"l": np.where(dones, self.episode_lengths, 0),
|
"l": np.where(dones, self.episode_lengths, 0),
|
||||||
"t": np.where(
|
"t": np.where(dones, episode_time_length, 0.0),
|
||||||
dones,
|
|
||||||
np.round(time.perf_counter() - self.episode_start_times, 6),
|
|
||||||
0.0,
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
infos["_episode"] = dones
|
infos[f"_{self._stats_key}"] = dones
|
||||||
|
|
||||||
self.episode_count += num_dones
|
self.episode_count += num_dones
|
||||||
|
|
||||||
for i in np.where(dones):
|
for i in np.where(dones):
|
||||||
|
self.time_queue.extend(episode_time_length[i])
|
||||||
self.return_queue.extend(self.episode_returns[i])
|
self.return_queue.extend(self.episode_returns[i])
|
||||||
self.length_queue.extend(self.episode_lengths[i])
|
self.length_queue.extend(self.episode_lengths[i])
|
||||||
|
|
||||||
|
@@ -3,6 +3,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from gymnasium.core import ActType, ObsType
|
from gymnasium.core import ActType, ObsType
|
||||||
from gymnasium.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
|
from gymnasium.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
|
||||||
|
|
||||||
@@ -78,6 +80,7 @@ class DictInfoToList(VectorWrapper):
|
|||||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]:
|
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]:
|
||||||
"""Steps through the environment, convert dict info to list."""
|
"""Steps through the environment, convert dict info to list."""
|
||||||
observation, reward, terminated, truncated, infos = self.env.step(actions)
|
observation, reward, terminated, truncated, infos = self.env.step(actions)
|
||||||
|
assert isinstance(infos, dict)
|
||||||
list_info = self._convert_info_to_list(infos)
|
list_info = self._convert_info_to_list(infos)
|
||||||
|
|
||||||
return observation, reward, terminated, truncated, list_info
|
return observation, reward, terminated, truncated, list_info
|
||||||
@@ -90,11 +93,12 @@ class DictInfoToList(VectorWrapper):
|
|||||||
) -> tuple[ObsType, list[dict[str, Any]]]:
|
) -> tuple[ObsType, list[dict[str, Any]]]:
|
||||||
"""Resets the environment using kwargs."""
|
"""Resets the environment using kwargs."""
|
||||||
obs, infos = self.env.reset(seed=seed, options=options)
|
obs, infos = self.env.reset(seed=seed, options=options)
|
||||||
|
assert isinstance(infos, dict)
|
||||||
list_info = self._convert_info_to_list(infos)
|
list_info = self._convert_info_to_list(infos)
|
||||||
|
|
||||||
return obs, list_info
|
return obs, list_info
|
||||||
|
|
||||||
def _convert_info_to_list(self, infos: dict) -> list[dict[str, Any]]:
|
def _convert_info_to_list(self, vector_infos: dict) -> list[dict[str, Any]]:
|
||||||
"""Convert the dict info to list.
|
"""Convert the dict info to list.
|
||||||
|
|
||||||
Convert the dict info of the vectorized environment
|
Convert the dict info of the vectorized environment
|
||||||
@@ -102,52 +106,28 @@ class DictInfoToList(VectorWrapper):
|
|||||||
has the info of the i-th environment.
|
has the info of the i-th environment.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
infos (dict): info dict coming from the env.
|
vector_infos (dict): info dict coming from the env.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list_info (list): converted info.
|
list_info (list): converted info.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
list_info = [{} for _ in range(self.num_envs)]
|
list_info = [{} for _ in range(self.num_envs)]
|
||||||
list_info = self._process_episode_statistics(infos, list_info)
|
|
||||||
for k in infos:
|
for key, value in vector_infos.items():
|
||||||
if k.startswith("_"):
|
if key.startswith("_"):
|
||||||
continue
|
continue
|
||||||
for i, has_info in enumerate(infos[f"_{k}"]):
|
|
||||||
if has_info:
|
|
||||||
list_info[i][k] = infos[k][i]
|
|
||||||
return list_info
|
|
||||||
|
|
||||||
# todo - I think this function should be more general for any information
|
if isinstance(value, dict):
|
||||||
def _process_episode_statistics(self, infos: dict, list_info: list) -> list[dict]:
|
value_list_info = self._convert_info_to_list(value)
|
||||||
"""Process episode statistics.
|
for env_num, (env_info, has_info) in enumerate(
|
||||||
|
zip(value_list_info, vector_infos[f"_{key}"])
|
||||||
`RecordEpisodeStatistics` wrapper add extra
|
):
|
||||||
information to the info. This information are in
|
if has_info:
|
||||||
the form of a dict of dict. This method process these
|
list_info[env_num][key] = env_info
|
||||||
information and add them to the info.
|
else:
|
||||||
`RecordEpisodeStatistics` info contains the keys
|
assert isinstance(value, np.ndarray)
|
||||||
"r", "l", "t" which represents "cumulative reward",
|
for env_num, has_info in enumerate(vector_infos[f"_{key}"]):
|
||||||
"episode length", "elapsed time since instantiation of wrapper".
|
if has_info:
|
||||||
|
list_info[env_num][key] = value[env_num]
|
||||||
Args:
|
|
||||||
infos (dict): infos coming from `RecordEpisodeStatistics`.
|
|
||||||
list_info (list): info of the current vectorized environment.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list_info (list): updated info.
|
|
||||||
|
|
||||||
"""
|
|
||||||
episode_statistics = infos.pop("episode", False)
|
|
||||||
if not episode_statistics:
|
|
||||||
return list_info
|
|
||||||
|
|
||||||
episode_statistics_mask = infos.pop("_episode")
|
|
||||||
for i, has_info in enumerate(episode_statistics_mask):
|
|
||||||
if has_info:
|
|
||||||
list_info[i]["episode"] = {}
|
|
||||||
list_info[i]["episode"]["r"] = episode_statistics["r"][i]
|
|
||||||
list_info[i]["episode"]["l"] = episode_statistics["l"][i]
|
|
||||||
list_info[i]["episode"]["t"] = episode_statistics["t"][i]
|
|
||||||
|
|
||||||
return list_info
|
return list_info
|
||||||
|
@@ -34,9 +34,9 @@ class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructor
|
|||||||
>>> for _ in range(100):
|
>>> for _ in range(100):
|
||||||
... obs, *_ = envs.step(envs.action_space.sample())
|
... obs, *_ = envs.step(envs.action_space.sample())
|
||||||
>>> np.mean(obs)
|
>>> np.mean(obs)
|
||||||
-0.017698428
|
0.024251968
|
||||||
>>> np.std(obs)
|
>>> np.std(obs)
|
||||||
0.62041104
|
0.62259156
|
||||||
>>> envs.close()
|
>>> envs.close()
|
||||||
|
|
||||||
Example with the normalize reward wrapper:
|
Example with the normalize reward wrapper:
|
||||||
@@ -48,9 +48,9 @@ class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructor
|
|||||||
>>> for _ in range(100):
|
>>> for _ in range(100):
|
||||||
... obs, *_ = envs.step(envs.action_space.sample())
|
... obs, *_ = envs.step(envs.action_space.sample())
|
||||||
>>> np.mean(obs)
|
>>> np.mean(obs)
|
||||||
-0.28381696
|
-0.2359734
|
||||||
>>> np.std(obs)
|
>>> np.std(obs)
|
||||||
1.21742
|
1.1938739
|
||||||
>>> envs.close()
|
>>> envs.close()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -81,29 +81,15 @@ class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructor
|
|||||||
"""Sets the property to freeze/continue the running mean calculation of the observation statistics."""
|
"""Sets the property to freeze/continue the running mean calculation of the observation statistics."""
|
||||||
self._update_running_mean = setting
|
self._update_running_mean = setting
|
||||||
|
|
||||||
def vector_observation(self, observation: ObsType) -> ObsType:
|
def observation(self, observations: ObsType) -> ObsType:
|
||||||
"""Defines the vector observation normalization function.
|
"""Defines the vector observation normalization function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
observation: A vector observation from the environment
|
observations: A vector observation from the environment
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
the normalized observation
|
the normalized observation
|
||||||
"""
|
"""
|
||||||
return self._normalize_observations(observation)
|
|
||||||
|
|
||||||
def single_observation(self, observation: ObsType) -> ObsType:
|
|
||||||
"""Defines the single observation normalization function.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
observation: A single observation from the environment
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The normalized observation
|
|
||||||
"""
|
|
||||||
return self._normalize_observations(observation[None])
|
|
||||||
|
|
||||||
def _normalize_observations(self, observations: ObsType) -> ObsType:
|
|
||||||
if self._update_running_mean:
|
if self._update_running_mean:
|
||||||
self.obs_rms.update(observations)
|
self.obs_rms.update(observations)
|
||||||
return (observations - self.obs_rms.mean) / np.sqrt(
|
return (observations - self.obs_rms.mean) / np.sqrt(
|
||||||
|
@@ -37,13 +37,10 @@ class TransformObservation(VectorObservationWrapper):
|
|||||||
>>> def scale_and_shift(obs):
|
>>> def scale_and_shift(obs):
|
||||||
... return (obs - 1.0) * 2.0
|
... return (obs - 1.0) * 2.0
|
||||||
...
|
...
|
||||||
>>> def vector_scale_and_shift(obs):
|
|
||||||
... return (obs - 1.0) * 2.0
|
|
||||||
...
|
|
||||||
>>> import gymnasium as gym
|
>>> import gymnasium as gym
|
||||||
>>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
|
>>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
|
||||||
>>> new_obs_space = Box(low=envs.observation_space.low, high=envs.observation_space.high)
|
>>> new_obs_space = Box(low=envs.observation_space.low, high=envs.observation_space.high)
|
||||||
>>> envs = TransformObservation(envs, single_func=scale_and_shift, vector_func=vector_scale_and_shift)
|
>>> envs = TransformObservation(envs, func=scale_and_shift, observation_space=new_obs_space)
|
||||||
>>> obs, info = envs.reset(seed=123)
|
>>> obs, info = envs.reset(seed=123)
|
||||||
>>> obs
|
>>> obs
|
||||||
array([[-1.9635296, -2.0892358, -2.055928 , -2.0631256],
|
array([[-1.9635296, -2.0892358, -2.055928 , -2.0631256],
|
||||||
@@ -55,16 +52,14 @@ class TransformObservation(VectorObservationWrapper):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: VectorEnv,
|
env: VectorEnv,
|
||||||
vector_func: Callable[[ObsType], Any],
|
func: Callable[[ObsType], Any],
|
||||||
single_func: Callable[[ObsType], Any],
|
|
||||||
observation_space: Space | None = None,
|
observation_space: Space | None = None,
|
||||||
):
|
):
|
||||||
"""Constructor for the transform observation wrapper.
|
"""Constructor for the transform observation wrapper.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
env: The vector environment to wrap
|
env: The vector environment to wrap
|
||||||
vector_func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
|
func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
|
||||||
single_func: A function that will transform an individual observation, this function will be used for the final observation from the environment and is returned under ``info`` and not the normal observation.
|
|
||||||
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
|
observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
@@ -72,16 +67,11 @@ class TransformObservation(VectorObservationWrapper):
|
|||||||
if observation_space is not None:
|
if observation_space is not None:
|
||||||
self.observation_space = observation_space
|
self.observation_space = observation_space
|
||||||
|
|
||||||
self.vector_func = vector_func
|
self.func = func
|
||||||
self.single_func = single_func
|
|
||||||
|
|
||||||
def vector_observation(self, observation: ObsType) -> ObsType:
|
def observation(self, observations: ObsType) -> ObsType:
|
||||||
"""Apply function to the vector observation."""
|
"""Apply function to the vector observation."""
|
||||||
return self.vector_func(observation)
|
return self.func(observations)
|
||||||
|
|
||||||
def single_observation(self, observation: ObsType) -> ObsType:
|
|
||||||
"""Apply function to the single observation."""
|
|
||||||
return self.single_func(observation)
|
|
||||||
|
|
||||||
|
|
||||||
class VectorizeTransformObservation(VectorObservationWrapper):
|
class VectorizeTransformObservation(VectorObservationWrapper):
|
||||||
@@ -158,16 +148,16 @@ class VectorizeTransformObservation(VectorObservationWrapper):
|
|||||||
self.same_out = self.observation_space == self.env.observation_space
|
self.same_out = self.observation_space == self.env.observation_space
|
||||||
self.out = create_empty_array(self.single_observation_space, self.num_envs)
|
self.out = create_empty_array(self.single_observation_space, self.num_envs)
|
||||||
|
|
||||||
def vector_observation(self, observation: ObsType) -> ObsType:
|
def observation(self, observations: ObsType) -> ObsType:
|
||||||
"""Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again."""
|
"""Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again."""
|
||||||
if self.same_out:
|
if self.same_out:
|
||||||
return concatenate(
|
return concatenate(
|
||||||
self.single_observation_space,
|
self.single_observation_space,
|
||||||
tuple(
|
tuple(
|
||||||
self.wrapper.func(obs)
|
self.wrapper.func(obs)
|
||||||
for obs in iterate(self.observation_space, observation)
|
for obs in iterate(self.observation_space, observations)
|
||||||
),
|
),
|
||||||
observation,
|
observations,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return deepcopy(
|
return deepcopy(
|
||||||
@@ -175,16 +165,12 @@ class VectorizeTransformObservation(VectorObservationWrapper):
|
|||||||
self.single_observation_space,
|
self.single_observation_space,
|
||||||
tuple(
|
tuple(
|
||||||
self.wrapper.func(obs)
|
self.wrapper.func(obs)
|
||||||
for obs in iterate(self.env.observation_space, observation)
|
for obs in iterate(self.env.observation_space, observations)
|
||||||
),
|
),
|
||||||
self.out,
|
self.out,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def single_observation(self, observation: ObsType) -> ObsType:
|
|
||||||
"""Transforms a single observation using the wrapper transformation function."""
|
|
||||||
return self.wrapper.func(observation)
|
|
||||||
|
|
||||||
|
|
||||||
class FilterObservation(VectorizeTransformObservation):
|
class FilterObservation(VectorizeTransformObservation):
|
||||||
"""Vector wrapper for filtering dict or tuple observation spaces.
|
"""Vector wrapper for filtering dict or tuple observation spaces.
|
||||||
|
@@ -6,8 +6,14 @@ import jax.numpy as jnp # noqa: E402
|
|||||||
import jax.random as jrng # noqa: E402
|
import jax.random as jrng # noqa: E402
|
||||||
import numpy as np # noqa: E402
|
import numpy as np # noqa: E402
|
||||||
|
|
||||||
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional # noqa: E402
|
from gymnasium.envs.phys2d.cartpole import ( # noqa: E402
|
||||||
from gymnasium.envs.phys2d.pendulum import PendulumFunctional # noqa: E402
|
CartPoleFunctional,
|
||||||
|
CartPoleJaxVectorEnv,
|
||||||
|
)
|
||||||
|
from gymnasium.envs.phys2d.pendulum import ( # noqa: E402
|
||||||
|
PendulumFunctional,
|
||||||
|
PendulumJaxVectorEnv,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
|
||||||
@@ -105,3 +111,34 @@ def test_vmap(env_class):
|
|||||||
assert obs.dtype == jnp.float32
|
assert obs.dtype == jnp.float32
|
||||||
|
|
||||||
state = next_state
|
state = next_state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("env_class", [CartPoleJaxVectorEnv, PendulumJaxVectorEnv])
|
||||||
|
def test_vectorized(env_class):
|
||||||
|
env = env_class(num_envs=10)
|
||||||
|
env.action_space.seed(0)
|
||||||
|
|
||||||
|
obs, info = env.reset(seed=0)
|
||||||
|
assert obs.shape == (10,) + env.single_observation_space.shape
|
||||||
|
assert isinstance(obs, np.ndarray)
|
||||||
|
assert isinstance(info, dict)
|
||||||
|
|
||||||
|
for t in range(100):
|
||||||
|
action = env.action_space.sample()
|
||||||
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
|
|
||||||
|
assert obs.shape == (10,) + env.single_observation_space.shape
|
||||||
|
assert isinstance(obs, np.ndarray)
|
||||||
|
assert reward.shape == (10,)
|
||||||
|
assert isinstance(reward, np.ndarray)
|
||||||
|
assert terminated.shape == (10,)
|
||||||
|
assert isinstance(terminated, np.ndarray)
|
||||||
|
assert truncated.shape == (10,)
|
||||||
|
assert isinstance(truncated, np.ndarray)
|
||||||
|
assert isinstance(info, dict)
|
||||||
|
|
||||||
|
# These were removed in the new autoreset order
|
||||||
|
assert "final_observation" not in info
|
||||||
|
assert "final_info" not in info
|
||||||
|
assert "_final_observation" not in info
|
||||||
|
assert "_final_info" not in info
|
||||||
|
@@ -33,7 +33,7 @@ def test_create_async_vector_env(shared_memory):
|
|||||||
|
|
||||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||||
def test_reset_async_vector_env(shared_memory):
|
def test_reset_async_vector_env(shared_memory):
|
||||||
"""Test the reset of an sync vector environment with or without shared memory."""
|
"""Test the reset of an async vector environment with or without shared memory."""
|
||||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||||
|
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
|
@@ -6,6 +6,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from gymnasium.spaces import Discrete
|
from gymnasium.spaces import Discrete
|
||||||
|
from gymnasium.utils.env_checker import data_equivalence
|
||||||
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
|
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
|
||||||
from tests.testing_env import GenericTestEnv
|
from tests.testing_env import GenericTestEnv
|
||||||
from tests.vector.testing_utils import make_env
|
from tests.vector.testing_utils import make_env
|
||||||
@@ -29,6 +30,7 @@ def test_vector_env_equal(shared_memory):
|
|||||||
async_observations, async_infos = async_env.reset(seed=0)
|
async_observations, async_infos = async_env.reset(seed=0)
|
||||||
sync_observations, sync_infos = sync_env.reset(seed=0)
|
sync_observations, sync_infos = sync_env.reset(seed=0)
|
||||||
assert np.all(async_observations == sync_observations)
|
assert np.all(async_observations == sync_observations)
|
||||||
|
assert data_equivalence(async_infos, sync_infos)
|
||||||
|
|
||||||
for _ in range(num_steps):
|
for _ in range(num_steps):
|
||||||
actions = async_env.action_space.sample()
|
actions = async_env.action_space.sample()
|
||||||
@@ -49,16 +51,11 @@ def test_vector_env_equal(shared_memory):
|
|||||||
sync_infos,
|
sync_infos,
|
||||||
) = sync_env.step(actions)
|
) = sync_env.step(actions)
|
||||||
|
|
||||||
if any(sync_terminations) or any(sync_truncations):
|
|
||||||
assert "final_observation" in async_infos
|
|
||||||
assert "_final_observation" in async_infos
|
|
||||||
assert "final_observation" in sync_infos
|
|
||||||
assert "_final_observation" in sync_infos
|
|
||||||
|
|
||||||
assert np.all(async_observations == sync_observations)
|
assert np.all(async_observations == sync_observations)
|
||||||
assert np.all(async_rewards == sync_rewards)
|
assert np.all(async_rewards == sync_rewards)
|
||||||
assert np.all(async_terminations == sync_terminations)
|
assert np.all(async_terminations == sync_terminations)
|
||||||
assert np.all(async_truncations == sync_truncations)
|
assert np.all(async_truncations == sync_truncations)
|
||||||
|
assert data_equivalence(async_infos, sync_infos)
|
||||||
|
|
||||||
async_env.close()
|
async_env.close()
|
||||||
sync_env.close()
|
sync_env.close()
|
||||||
@@ -115,14 +112,13 @@ def test_final_obs_info(vectoriser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
obs, _, termination, _, info = env.step([3])
|
obs, _, termination, _, info = env.step([3])
|
||||||
|
assert obs == np.array([0]) and info == {"action": 3, "_action": np.array([True])}
|
||||||
|
|
||||||
|
obs, _, terminated, _, info = env.step([4])
|
||||||
assert (
|
assert (
|
||||||
obs == np.array([0])
|
obs == np.array([0])
|
||||||
and termination == np.array([True])
|
and termination == np.array([True])
|
||||||
and info["reset"] == np.array([True])
|
and info["reset"] == np.array([True])
|
||||||
)
|
)
|
||||||
assert "final_observation" in info and "final_info" in info
|
|
||||||
assert info["final_observation"] == np.array([0]) and info["final_info"] == {
|
|
||||||
"action": 3
|
|
||||||
}
|
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
@@ -1,66 +1,164 @@
|
|||||||
"""Test the vector environment information."""
|
"""Test the vector environment information."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, SupportsFloat
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.vector.sync_vector_env import SyncVectorEnv
|
from gymnasium.core import ActType, ObsType
|
||||||
from tests.vector.testing_utils import make_env
|
from gymnasium.spaces import Box, Discrete
|
||||||
|
from gymnasium.utils.env_checker import data_equivalence
|
||||||
|
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv
|
||||||
|
|
||||||
|
|
||||||
ENV_ID = "CartPole-v1"
|
def test_vector_add_info():
|
||||||
NUM_ENVS = 3
|
env = VectorEnv()
|
||||||
ENV_STEPS = 50
|
|
||||||
SEED = 42
|
# Test num-envs==1 then expand_dims(sub-env-info) == vector-infos
|
||||||
|
env.num_envs = 1
|
||||||
|
sub_env_info = {"a": 0, "b": 0.0, "c": None, "d": np.zeros((2,)), "e": Discrete(1)}
|
||||||
|
vector_infos = env._add_info({}, sub_env_info, 0)
|
||||||
|
expected_vector_infos = {
|
||||||
|
"a": np.array([0]),
|
||||||
|
"b": np.array([0.0]),
|
||||||
|
"c": np.array([None], dtype=object),
|
||||||
|
"d": np.zeros(
|
||||||
|
(
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
"e": np.array([Discrete(1)], dtype=object),
|
||||||
|
"_a": np.array([True]),
|
||||||
|
"_b": np.array([True]),
|
||||||
|
"_c": np.array([True]),
|
||||||
|
"_d": np.array([True]),
|
||||||
|
"_e": np.array([True]),
|
||||||
|
}
|
||||||
|
assert data_equivalence(vector_infos, expected_vector_infos)
|
||||||
|
|
||||||
|
# Thought: num-envs>1 then vector-infos should have the same structure as sub-env-info
|
||||||
|
env.num_envs = 3
|
||||||
|
sub_env_infos = [
|
||||||
|
{"a": 0, "b": 0.0, "c": None, "d": np.zeros((2,)), "e": Discrete(1)},
|
||||||
|
{"a": 1, "b": 1.0, "c": None, "d": np.zeros((2,)), "e": Discrete(2)},
|
||||||
|
{"a": 2, "b": 2.0, "c": None, "d": np.zeros((2,)), "e": Discrete(3)},
|
||||||
|
]
|
||||||
|
|
||||||
|
vector_infos = {}
|
||||||
|
for i, info in enumerate(sub_env_infos):
|
||||||
|
vector_infos = env._add_info(vector_infos, info, i)
|
||||||
|
|
||||||
|
expected_vector_infos = {
|
||||||
|
"a": np.array([0, 1, 2]),
|
||||||
|
"b": np.array([0.0, 1.0, 2.0]),
|
||||||
|
"c": np.array([None, None, None], dtype=object),
|
||||||
|
"d": np.zeros((3, 2)),
|
||||||
|
"e": np.array([Discrete(1), Discrete(2), Discrete(3)], dtype=object),
|
||||||
|
"_a": np.array([True, True, True]),
|
||||||
|
"_b": np.array([True, True, True]),
|
||||||
|
"_c": np.array([True, True, True]),
|
||||||
|
"_d": np.array([True, True, True]),
|
||||||
|
"_e": np.array([True, True, True]),
|
||||||
|
}
|
||||||
|
assert data_equivalence(vector_infos, expected_vector_infos)
|
||||||
|
|
||||||
|
# Test different structures of sub-infos
|
||||||
|
env.num_envs = 3
|
||||||
|
sub_env_infos = [
|
||||||
|
{"a": 1, "b": 1.0},
|
||||||
|
{"c": None, "d": np.zeros((2,))},
|
||||||
|
{"e": Discrete(3)},
|
||||||
|
]
|
||||||
|
|
||||||
|
vector_infos = {}
|
||||||
|
for i, info in enumerate(sub_env_infos):
|
||||||
|
vector_infos = env._add_info(vector_infos, info, i)
|
||||||
|
|
||||||
|
expected_vector_infos = {
|
||||||
|
"a": np.array([1, 0, 0]),
|
||||||
|
"b": np.array([1.0, 0.0, 0.0]),
|
||||||
|
"c": np.array([None, None, None], dtype=object),
|
||||||
|
"d": np.zeros((3, 2)),
|
||||||
|
"e": np.array([None, None, Discrete(3)], dtype=object),
|
||||||
|
"_a": np.array([True, False, False]),
|
||||||
|
"_b": np.array([True, False, False]),
|
||||||
|
"_c": np.array([False, True, False]),
|
||||||
|
"_d": np.array([False, True, False]),
|
||||||
|
"_e": np.array([False, False, True]),
|
||||||
|
}
|
||||||
|
assert data_equivalence(vector_infos, expected_vector_infos)
|
||||||
|
|
||||||
|
# Test recursive structure
|
||||||
|
env.num_envs = 3
|
||||||
|
sub_env_infos = [
|
||||||
|
{"episode": {"a": 1, "b": 1.0}},
|
||||||
|
{"episode": {"a": 2, "b": 2.0}, "a": 1},
|
||||||
|
{"a": 2},
|
||||||
|
]
|
||||||
|
|
||||||
|
vector_infos = {}
|
||||||
|
for i, info in enumerate(sub_env_infos):
|
||||||
|
vector_infos = env._add_info(vector_infos, info, i)
|
||||||
|
|
||||||
|
expected_vector_infos = {
|
||||||
|
"episode": {
|
||||||
|
"a": np.array([1, 2, 0]),
|
||||||
|
"b": np.array([1.0, 2.0, 0.0]),
|
||||||
|
"_a": np.array([True, True, False]),
|
||||||
|
"_b": np.array([True, True, False]),
|
||||||
|
},
|
||||||
|
"_episode": np.array([True, True, False]),
|
||||||
|
"a": np.array([0, 1, 2]),
|
||||||
|
"_a": np.array([False, True, True]),
|
||||||
|
}
|
||||||
|
assert data_equivalence(vector_infos, expected_vector_infos)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("vectorization_mode", ["async", "sync"])
|
class ReturnInfoEnv(gym.Env):
|
||||||
def test_vector_env_info(vectorization_mode: str):
|
def __init__(self, infos):
|
||||||
"""Test vector environment info for different vectorization modes."""
|
self.observation_space = Box(0, 1)
|
||||||
env = gym.make_vec(
|
self.action_space = Box(0, 1)
|
||||||
ENV_ID,
|
|
||||||
num_envs=NUM_ENVS,
|
self.infos = infos
|
||||||
vectorization_mode=vectorization_mode,
|
|
||||||
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: int | None = None,
|
||||||
|
options: dict[str, Any] | None = None,
|
||||||
|
) -> tuple[ObsType, dict[str, Any]]:
|
||||||
|
return self.observation_space.sample(), self.infos[0]
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self, action: ActType
|
||||||
|
) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||||
|
return self.observation_space.sample(), 0, True, False, self.infos[1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("vectorizer", [AsyncVectorEnv, SyncVectorEnv])
|
||||||
|
def test_vectorizers(vectorizer):
|
||||||
|
vec_env = vectorizer(
|
||||||
|
[
|
||||||
|
lambda: ReturnInfoEnv([{"a": 1}, {"c": np.array([1, 2])}]),
|
||||||
|
lambda: ReturnInfoEnv([{"a": 2, "b": 3}, {"c": np.array([3, 4])}]),
|
||||||
|
]
|
||||||
)
|
)
|
||||||
env.reset(seed=SEED)
|
|
||||||
for _ in range(ENV_STEPS):
|
|
||||||
env.action_space.seed(SEED)
|
|
||||||
action = env.action_space.sample()
|
|
||||||
_, _, terminateds, truncateds, infos = env.step(action)
|
|
||||||
if any(terminateds) or any(truncateds):
|
|
||||||
assert len(infos["final_observation"]) == NUM_ENVS
|
|
||||||
assert len(infos["_final_observation"]) == NUM_ENVS
|
|
||||||
|
|
||||||
assert isinstance(infos["final_observation"], np.ndarray)
|
reset_expected_infos = {
|
||||||
assert isinstance(infos["_final_observation"], np.ndarray)
|
"a": np.array([1, 2]),
|
||||||
|
"b": np.array([0, 3]),
|
||||||
|
"_a": np.array([True, True]),
|
||||||
|
"_b": np.array([False, True]),
|
||||||
|
}
|
||||||
|
step_expected_infos = {
|
||||||
|
"c": np.array([[1, 2], [3, 4]]),
|
||||||
|
"_c": np.array([True, True]),
|
||||||
|
}
|
||||||
|
|
||||||
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
|
_, reset_info = vec_env.reset()
|
||||||
if terminated or truncated:
|
assert data_equivalence(reset_info, reset_expected_infos)
|
||||||
assert infos["_final_observation"][i]
|
_, _, _, _, step_info = vec_env.step(vec_env.action_space.sample())
|
||||||
else:
|
assert data_equivalence(step_info, step_expected_infos)
|
||||||
assert not infos["_final_observation"][i]
|
|
||||||
assert infos["final_observation"][i] is None
|
|
||||||
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("concurrent_ends", [1, 2, 3])
|
|
||||||
def test_vector_env_info_concurrent_termination(concurrent_ends):
|
|
||||||
"""Test the vector environment information works with concurrent termination."""
|
|
||||||
# envs that need to terminate together will have the same action
|
|
||||||
actions = [0] * concurrent_ends + [1] * (NUM_ENVS - concurrent_ends)
|
|
||||||
envs = [make_env(ENV_ID, SEED) for _ in range(NUM_ENVS)]
|
|
||||||
envs = SyncVectorEnv(envs)
|
|
||||||
|
|
||||||
for _ in range(ENV_STEPS):
|
|
||||||
_, _, terminateds, truncateds, infos = envs.step(actions)
|
|
||||||
if any(terminateds) or any(truncateds):
|
|
||||||
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
|
|
||||||
if i < concurrent_ends:
|
|
||||||
assert terminated or truncated
|
|
||||||
assert infos["_final_observation"][i]
|
|
||||||
else:
|
|
||||||
assert not infos["_final_observation"][i]
|
|
||||||
assert infos["final_observation"][i] is None
|
|
||||||
return
|
|
||||||
|
|
||||||
envs.close()
|
|
||||||
|
@@ -45,19 +45,15 @@ def test_autoreset_wrapper_autoreset():
|
|||||||
assert info == {"count": 2}
|
assert info == {"count": 2}
|
||||||
|
|
||||||
obs, reward, terminated, truncated, info = env.step(action)
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
assert obs == np.array([0])
|
assert obs == np.array([3])
|
||||||
assert (terminated or truncated) is True
|
assert (terminated or truncated) is True
|
||||||
assert reward == 1
|
assert reward == 1
|
||||||
assert info == {
|
assert info == {"count": 3}
|
||||||
"count": 0,
|
|
||||||
"final_observation": np.array([3]),
|
|
||||||
"final_info": {"count": 3},
|
|
||||||
}
|
|
||||||
|
|
||||||
obs, reward, terminated, truncated, info = env.step(action)
|
obs, reward, terminated, truncated, info = env.step(action)
|
||||||
assert obs == np.array([1])
|
assert obs == np.array([0])
|
||||||
assert reward == 0
|
assert reward == 0
|
||||||
assert (terminated or truncated) is False
|
assert (terminated or truncated) is False
|
||||||
assert info == {"count": 1}
|
assert info == {"count": 0}
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
@@ -1,21 +1,22 @@
|
|||||||
"""Test suite for DictInfoToList wrapper."""
|
"""Test suite for DictInfoToList wrapper."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.wrappers.vector import DictInfoToList, RecordEpisodeStatistics
|
from gymnasium.core import ObsType
|
||||||
|
from gymnasium.spaces import Discrete
|
||||||
|
from gymnasium.utils.env_checker import data_equivalence
|
||||||
|
from gymnasium.vector import VectorEnv
|
||||||
|
from gymnasium.wrappers.vector import DictInfoToList
|
||||||
|
|
||||||
|
|
||||||
ENV_ID = "CartPole-v1"
|
def test_usage_in_vector_env(env_id: str = "CartPole-v1", num_envs: int = 3):
|
||||||
NUM_ENVS = 3
|
env = gym.make(env_id, disable_env_checker=True)
|
||||||
ENV_STEPS = 50
|
vector_env = gym.make_vec(env_id, num_envs=num_envs)
|
||||||
SEED = 42
|
|
||||||
|
|
||||||
|
|
||||||
def test_usage_in_vector_env():
|
|
||||||
env = gym.make(ENV_ID, disable_env_checker=True)
|
|
||||||
vector_env = gym.make_vec(ENV_ID, num_envs=NUM_ENVS, vectorization_mode="sync")
|
|
||||||
|
|
||||||
DictInfoToList(vector_env)
|
DictInfoToList(vector_env)
|
||||||
|
|
||||||
@@ -23,40 +24,140 @@ def test_usage_in_vector_env():
|
|||||||
DictInfoToList(env)
|
DictInfoToList(env)
|
||||||
|
|
||||||
|
|
||||||
def test_info_to_list():
|
class ResetOptionAsInfo(VectorEnv):
|
||||||
env_to_wrap = gym.make_vec(ENV_ID, num_envs=NUM_ENVS, vectorization_mode="sync")
|
"""Minimal implementation to test the conversion of vector dict info to list info."""
|
||||||
wrapped_env = DictInfoToList(env_to_wrap)
|
|
||||||
wrapped_env.action_space.seed(SEED)
|
|
||||||
_, info = wrapped_env.reset(seed=SEED)
|
|
||||||
assert isinstance(info, list)
|
|
||||||
assert len(info) == NUM_ENVS
|
|
||||||
|
|
||||||
for _ in range(ENV_STEPS):
|
def reset(
|
||||||
action = wrapped_env.action_space.sample()
|
self,
|
||||||
_, _, terminateds, truncateds, list_info = wrapped_env.step(action)
|
*,
|
||||||
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
|
seed: int | list[int] | None = None,
|
||||||
if terminated or truncated:
|
options: dict[str, Any] | None = None, # options are passed as the info output
|
||||||
assert "final_observation" in list_info[i]
|
) -> tuple[ObsType, dict[str, Any]]:
|
||||||
else:
|
return None, options
|
||||||
assert "final_observation" not in list_info[i]
|
|
||||||
|
|
||||||
|
|
||||||
def test_info_to_list_statistics():
|
def test_update_info():
|
||||||
env_to_wrap = gym.make_vec(ENV_ID, num_envs=NUM_ENVS, vectorization_mode="sync")
|
env = DictInfoToList(ResetOptionAsInfo())
|
||||||
wrapped_env = DictInfoToList(RecordEpisodeStatistics(env_to_wrap))
|
|
||||||
_, info = wrapped_env.reset(seed=SEED)
|
|
||||||
wrapped_env.action_space.seed(SEED)
|
|
||||||
assert isinstance(info, list)
|
|
||||||
assert len(info) == NUM_ENVS
|
|
||||||
|
|
||||||
for _ in range(ENV_STEPS):
|
# Test num-envs==1 then expand_dims(sub-env-info) == vector-infos
|
||||||
action = wrapped_env.action_space.sample()
|
env.unwrapped.num_envs = 1
|
||||||
_, _, terminateds, truncateds, list_info = wrapped_env.step(action)
|
|
||||||
for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
|
vector_infos = {
|
||||||
if terminated or truncated:
|
"a": np.array([0]),
|
||||||
assert "episode" in list_info[i]
|
"b": np.array([0.0]),
|
||||||
for stats in ["r", "l", "t"]:
|
"c": np.array([None], dtype=object),
|
||||||
assert stats in list_info[i]["episode"]
|
"d": np.zeros(
|
||||||
assert np.isscalar(list_info[i]["episode"][stats])
|
(
|
||||||
else:
|
1,
|
||||||
assert "episode" not in list_info[i]
|
2,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
"e": np.array([Discrete(1)], dtype=object),
|
||||||
|
"_a": np.array([True]),
|
||||||
|
"_b": np.array([True]),
|
||||||
|
"_c": np.array([True]),
|
||||||
|
"_d": np.array([True]),
|
||||||
|
"_e": np.array([True]),
|
||||||
|
}
|
||||||
|
_, list_info = env.reset(options=vector_infos)
|
||||||
|
expected_list_info = [
|
||||||
|
{
|
||||||
|
"a": np.int64(0),
|
||||||
|
"b": np.float64(0.0),
|
||||||
|
"c": None,
|
||||||
|
"d": np.zeros((2,)),
|
||||||
|
"e": Discrete(1),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
assert data_equivalence(list_info, expected_list_info)
|
||||||
|
|
||||||
|
# Thought: num-envs>1 then vector-infos should have the same structure as sub-env-info
|
||||||
|
env.unwrapped.num_envs = 3
|
||||||
|
|
||||||
|
vector_infos = {
|
||||||
|
"a": np.array([0, 1, 2]),
|
||||||
|
"b": np.array([0.0, 1.0, 2.0]),
|
||||||
|
"c": np.array([None, None, None], dtype=object),
|
||||||
|
"d": np.zeros((3, 2)),
|
||||||
|
"e": np.array([Discrete(1), Discrete(2), Discrete(3)], dtype=object),
|
||||||
|
"_a": np.array([True, True, True]),
|
||||||
|
"_b": np.array([True, True, True]),
|
||||||
|
"_c": np.array([True, True, True]),
|
||||||
|
"_d": np.array([True, True, True]),
|
||||||
|
"_e": np.array([True, True, True]),
|
||||||
|
}
|
||||||
|
_, list_info = env.reset(options=vector_infos)
|
||||||
|
expected_list_info = [
|
||||||
|
{
|
||||||
|
"a": np.int64(0),
|
||||||
|
"b": np.float64(0.0),
|
||||||
|
"c": None,
|
||||||
|
"d": np.zeros((2,)),
|
||||||
|
"e": Discrete(1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"a": np.int64(1),
|
||||||
|
"b": np.float64(1.0),
|
||||||
|
"c": None,
|
||||||
|
"d": np.zeros((2,)),
|
||||||
|
"e": Discrete(2),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"a": np.int64(2),
|
||||||
|
"b": np.float64(2.0),
|
||||||
|
"c": None,
|
||||||
|
"d": np.zeros((2,)),
|
||||||
|
"e": Discrete(3),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
assert list_info[0].keys() == expected_list_info[0].keys()
|
||||||
|
for key in list_info[0].keys():
|
||||||
|
assert data_equivalence(list_info[0][key], expected_list_info[0][key])
|
||||||
|
assert data_equivalence(list_info, expected_list_info)
|
||||||
|
|
||||||
|
# Test different structures of sub-infos
|
||||||
|
env.unwrapped.num_envs = 3
|
||||||
|
|
||||||
|
vector_infos = {
|
||||||
|
"a": np.array([1, 0, 0]),
|
||||||
|
"_a": np.array([True, False, False]),
|
||||||
|
"b": np.array([1.0, 0.0, 0.0]),
|
||||||
|
"_b": np.array([True, False, False]),
|
||||||
|
"c": np.array([None, None, None], dtype=object),
|
||||||
|
"_c": np.array([False, True, False]),
|
||||||
|
"_d": np.array([False, True, False]),
|
||||||
|
"d": np.zeros((3, 2)),
|
||||||
|
"e": np.array([None, None, Discrete(3)], dtype=object),
|
||||||
|
"_e": np.array([False, False, True]),
|
||||||
|
}
|
||||||
|
_, list_info = env.reset(options=vector_infos)
|
||||||
|
expected_list_info = [
|
||||||
|
{"a": np.int64(1), "b": np.float64(1.0)},
|
||||||
|
{"c": None, "d": np.zeros((2,))},
|
||||||
|
{"e": Discrete(3)},
|
||||||
|
]
|
||||||
|
assert data_equivalence(list_info, expected_list_info)
|
||||||
|
|
||||||
|
# Test recursive structure
|
||||||
|
env.unwrapped.num_envs = 3
|
||||||
|
|
||||||
|
vector_infos = {
|
||||||
|
"episode": {
|
||||||
|
"a": np.array([1, 2, 0]),
|
||||||
|
"b": np.array([1.0, 2.0, 0.0]),
|
||||||
|
"_a": np.array([True, True, False]),
|
||||||
|
"_b": np.array([True, True, False]),
|
||||||
|
},
|
||||||
|
"_episode": np.array([True, True, False]),
|
||||||
|
"a": np.array([0, 1, 2]),
|
||||||
|
"_a": np.array([False, True, True]),
|
||||||
|
}
|
||||||
|
_, list_info = env.reset(options=vector_infos)
|
||||||
|
expected_list_info = [
|
||||||
|
{"episode": {"a": np.int64(1), "b": np.float64(1.0)}},
|
||||||
|
{"episode": {"a": np.int64(2), "b": np.float64(2.0)}, "a": np.int64(1)},
|
||||||
|
{"a": np.int64(2)},
|
||||||
|
]
|
||||||
|
assert data_equivalence(list_info, expected_list_info)
|
||||||
|
71
tests/wrappers/vector/test_record_episode_statistics.py
Normal file
71
tests/wrappers/vector/test_record_episode_statistics.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.utils.env_checker import data_equivalence
|
||||||
|
from gymnasium.vector import VectorEnv
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_envs", (1, 3))
def test_record_episode_statistics(num_envs, env_id="CartPole-v1", num_steps=100):
    """Compare the vector ``RecordEpisodeStatistics`` wrapper with per-env wrapping.

    Two synchronous vector environments are built from the same ``env_id``:

    * ``wrapper_vector_env``: the vector environment wrapped by
      ``gym.wrappers.vector.RecordEpisodeStatistics``.
    * ``vector_wrapper_env``: a vector environment whose sub-environments are
      each wrapped by ``gym.wrappers.RecordEpisodeStatistics``.

    Both are reset with the same seed and stepped with identical actions; their
    spaces, reset outputs and step outputs are compared.

    Args:
        num_envs: Number of sub-environments (parametrized).
        env_id: Environment id passed to ``gym.make_vec``.
        num_steps: Number of steps to roll both environments forward.
    """
    wrapper_vector_env: VectorEnv = gym.wrappers.vector.RecordEpisodeStatistics(
        gym.make_vec(id=env_id, num_envs=num_envs, vectorization_mode="sync"),
    )
    vector_wrapper_env = gym.make_vec(
        id=env_id,
        num_envs=num_envs,
        vectorization_mode="sync",
        wrappers=(gym.wrappers.RecordEpisodeStatistics,),
    )

    assert wrapper_vector_env.action_space == vector_wrapper_env.action_space
    assert wrapper_vector_env.observation_space == vector_wrapper_env.observation_space
    assert (
        wrapper_vector_env.single_action_space == vector_wrapper_env.single_action_space
    )
    assert (
        wrapper_vector_env.single_observation_space
        == vector_wrapper_env.single_observation_space
    )

    assert wrapper_vector_env.num_envs == vector_wrapper_env.num_envs

    wrapper_vector_obs, wrapper_vector_info = wrapper_vector_env.reset(seed=123)
    vector_wrapper_obs, vector_wrapper_info = vector_wrapper_env.reset(seed=123)

    assert data_equivalence(wrapper_vector_obs, vector_wrapper_obs)
    assert data_equivalence(wrapper_vector_info, vector_wrapper_info)

    for _ in range(num_steps):
        # The same sampled action drives both environments so they stay in lockstep.
        action = wrapper_vector_env.action_space.sample()
        (
            wrapper_vector_obs,
            wrapper_vector_reward,
            wrapper_vector_terminated,
            wrapper_vector_truncated,
            wrapper_vector_info,
        ) = wrapper_vector_env.step(action)
        (
            vector_wrapper_obs,
            vector_wrapper_reward,
            vector_wrapper_terminated,
            vector_wrapper_truncated,
            vector_wrapper_info,
        ) = vector_wrapper_env.step(action)

        # data_equivalence returns a bool; without `assert` these checks were no-ops.
        assert data_equivalence(wrapper_vector_obs, vector_wrapper_obs)
        assert data_equivalence(wrapper_vector_reward, vector_wrapper_reward)
        assert data_equivalence(wrapper_vector_terminated, vector_wrapper_terminated)
        assert data_equivalence(wrapper_vector_truncated, vector_wrapper_truncated)

        if "episode" in wrapper_vector_info:
            assert "episode" in vector_wrapper_info

            # "t" is removed from both infos and compared only by shape and dtype.
            wrapper_vector_time = wrapper_vector_info["episode"].pop("t")
            vector_wrapper_time = vector_wrapper_info["episode"].pop("t")
            assert wrapper_vector_time.shape == vector_wrapper_time.shape
            assert wrapper_vector_time.dtype == vector_wrapper_time.dtype

            # NOTE(review): the result of this comparison is still discarded, as in
            # the original. The two construction paths may lay out the remaining
            # info keys differently -- confirm they match before asserting here.
            data_equivalence(wrapper_vector_info, vector_wrapper_info)

    wrapper_vector_env.close()
    vector_wrapper_env.close()
|
@@ -42,21 +42,21 @@ def custom_environments():
|
|||||||
("CustomDictEnv-v0", "FilterObservation", {"filter_keys": ["a"]}),
|
("CustomDictEnv-v0", "FilterObservation", {"filter_keys": ["a"]}),
|
||||||
("CartPole-v1", "FlattenObservation", {}),
|
("CartPole-v1", "FlattenObservation", {}),
|
||||||
("CarRacing-v2", "GrayscaleObservation", {}),
|
("CarRacing-v2", "GrayscaleObservation", {}),
|
||||||
# ("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}),
|
("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}),
|
||||||
("CarRacing-v2", "ReshapeObservation", {"shape": (96, 48, 6)}),
|
("CarRacing-v2", "ReshapeObservation", {"shape": (96, 48, 6)}),
|
||||||
("CartPole-v1", "RescaleObservation", {"min_obs": 0, "max_obs": 1}),
|
("CartPole-v1", "RescaleObservation", {"min_obs": 0, "max_obs": 1}),
|
||||||
("CartPole-v1", "DtypeObservation", {"dtype": np.int32}),
|
("CartPole-v1", "DtypeObservation", {"dtype": np.int32}),
|
||||||
# ("CartPole-v1", "RenderObservation", {}),
|
# ("CartPole-v1", "RenderObservation", {}), # not implemented
|
||||||
# ("CartPole-v1", "TimeAwareObservation", {}),
|
# ("CartPole-v1", "TimeAwareObservation", {}), # not implemented
|
||||||
# ("CartPole-v1", "FrameStackObservation", {}),
|
# ("CartPole-v1", "FrameStackObservation", {}), # not implemented
|
||||||
# ("CartPole-v1", "DelayObservation", {}),
|
# ("CartPole-v1", "DelayObservation", {}), # not implemented
|
||||||
("MountainCarContinuous-v0", "ClipAction", {}),
|
("MountainCarContinuous-v0", "ClipAction", {}),
|
||||||
(
|
(
|
||||||
"MountainCarContinuous-v0",
|
"MountainCarContinuous-v0",
|
||||||
"RescaleAction",
|
"RescaleAction",
|
||||||
{"min_action": 1, "max_action": 2},
|
{"min_action": 1, "max_action": 2},
|
||||||
),
|
),
|
||||||
("CartPole-v1", "ClipReward", {"min_reward": 0.25, "max_reward": 0.75}),
|
("CartPole-v1", "ClipReward", {"min_reward": -0.25, "max_reward": 0.75}),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
def test_vector_wrapper_equivalence(
|
def test_vector_wrapper_equivalence(
|
||||||
|
Reference in New Issue
Block a user