diff --git a/gymnasium/envs/functional_jax_env.py b/gymnasium/envs/functional_jax_env.py index ee0ec7552..2fc5e55c7 100644 --- a/gymnasium/envs/functional_jax_env.py +++ b/gymnasium/envs/functional_jax_env.py @@ -151,6 +151,8 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv): self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32) + self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_) + self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box) if self.render_mode == "rgb_array": @@ -214,10 +216,9 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv): info = self.func_env.transition_info(self.state, action, next_state) done = jnp.logical_or(terminated, truncated) - if jnp.any(done): - final_obs = self.func_env.observation(next_state) - to_reset = jnp.where(done)[0] + if jnp.any(self.autoreset_envs): + to_reset = jnp.where(self.autoreset_envs)[0] reset_count = to_reset.shape[0] rng, self.rng = jrng.split(self.rng) @@ -228,34 +229,16 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv): next_state = self.state.at[to_reset].set(new_initials) self.steps = self.steps.at[to_reset].set(0) - # Get the final observations and infos - info["final_observation"] = np.array([None for _ in range(self.num_envs)]) - info["final_info"] = np.array([None for _ in range(self.num_envs)]) - - info["_final_observation"] = np.array([False for _ in range(self.num_envs)]) - info["_final_info"] = np.array([False for _ in range(self.num_envs)]) - - # TODO: this can maybe be optimized, but right now I don't know how - for i in to_reset: - info["final_observation"][i] = final_obs[i] - info["final_info"][i] = { - k: v[i] - for k, v in info.items() - if k - not in { - "final_observation", - "final_info", - "_final_observation", - "_final_info", - } - } - - info["_final_observation"][i] = True - info["_final_info"][i] = True + self.autoreset_envs = done observation = self.func_env.observation(next_state) observation = jax_to_numpy(observation) + reward = jax_to_numpy(reward) + + terminated = jax_to_numpy(terminated) + truncated = jax_to_numpy(truncated) + self.state = next_state return observation, reward, terminated, truncated, info diff --git a/gymnasium/vector/async_vector_env.py b/gymnasium/vector/async_vector_env.py index ea724cf25..4edef093b 100644 --- a/gymnasium/vector/async_vector_env.py +++ b/gymnasium/vector/async_vector_env.py @@ -639,6 +639,7 @@ def _async_worker( env = env_fn() observation_space = env.observation_space action_space = env.action_space + autoreset = False parent_pipe.close() @@ -653,20 +654,21 @@ def _async_worker( observation_space, index, observation, shared_memory ) observation = None + autoreset = False pipe.send(((observation, info), True)) elif command == "step": - ( - observation, - reward, - terminated, - truncated, - info, - ) = env.step(data) - if terminated or truncated: - old_observation, old_info = observation, info + if autoreset: observation, info = env.reset() - info["final_observation"] = old_observation - info["final_info"] = old_info + reward, terminated, truncated = 0, False, False + else: + ( + observation, + reward, + terminated, + truncated, + info, + ) = env.step(data) + autoreset = terminated or truncated if shared_memory: write_to_shared_memory( diff --git a/gymnasium/vector/sync_vector_env.py b/gymnasium/vector/sync_vector_env.py index eed598bd4..a9f234f2e 100644 --- a/gymnasium/vector/sync_vector_env.py +++ b/gymnasium/vector/sync_vector_env.py @@ -98,6 +98,8 @@ class SyncVectorEnv(VectorEnv): self._terminations = np.zeros((self.num_envs,), dtype=np.bool_) self._truncations = np.zeros((self.num_envs,), dtype=np.bool_) + self._autoreset_envs = np.zeros((self.num_envs,), dtype=np.bool_) + def reset( self, *, @@ -150,22 +152,21 @@ class SyncVectorEnv(VectorEnv): actions = iterate(self.action_space, actions) observations, infos = [], {} - for i, (env, action) in enumerate(zip(self.envs, actions)): - ( - env_obs, - self._rewards[i], - self._terminations[i], - self._truncations[i], - env_info, - ) = env.step(action) + for i, action in enumerate(actions): + if self._autoreset_envs[i]: + env_obs, env_info = self.envs[i].reset() - # If sub-environments terminates or truncates then save the obs and info to the batched info - if self._terminations[i] or self._truncations[i]: - old_observation, old_info = env_obs, env_info - env_obs, env_info = env.reset() - - env_info["final_observation"] = old_observation - env_info["final_info"] = old_info + self._rewards[i] = 0.0 + self._terminations[i] = False + self._truncations[i] = False + else: + ( + env_obs, + self._rewards[i], + self._terminations[i], + self._truncations[i], + env_info, + ) = self.envs[i].step(action) observations.append(env_obs) infos = self._add_info(infos, env_info, i) @@ -174,6 +175,7 @@ class SyncVectorEnv(VectorEnv): self._observations = concatenate( self.single_observation_space, observations, self._observations ) + self._autoreset_envs = np.logical_or(self._terminations, self._truncations) return ( deepcopy(self._observations) if self.copy else self._observations, diff --git a/gymnasium/vector/vector_env.py b/gymnasium/vector/vector_env.py index 6150977d1..63fe95f64 100644 --- a/gymnasium/vector/vector_env.py +++ b/gymnasium/vector/vector_env.py @@ -231,7 +231,7 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]): return self def _add_info( - self, infos: dict[str, Any], info: dict[str, Any], env_num: int + self, vector_infos: dict[str, Any], env_info: dict[str, Any], env_num: int ) -> dict[str, Any]: """Add env info to the info dictionary of the vectorized environment. @@ -241,48 +241,51 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]): whether or not the i-indexed environment has this `info`. Args: - infos (dict): the infos of the vectorized environment - info (dict): the info coming from the single environment + vector_infos (dict): the infos of the vectorized environment + env_info (dict): the info coming from the single environment env_num (int): the index of the single environment Returns: infos (dict): the (updated) infos of the vectorized environment - """ - for k in info.keys(): - if k not in infos: - info_array, array_mask = self._init_info_arrays(type(info[k])) + for key, value in env_info.items(): + # If value is a dictionary, then we apply the `_add_info` recursively. + if isinstance(value, dict): + array = self._add_info(vector_infos.get(key, {}), value, env_num) + # Otherwise, we are a base case to group the data else: - info_array, array_mask = infos[k], infos[f"_{k}"] + # If the key doesn't exist in the vector infos, then we can create an array of that batch type + if key not in vector_infos: + if type(value) in [int, float, bool] or issubclass( + type(value), np.number + ): + array = np.zeros(self.num_envs, dtype=type(value)) + elif isinstance(value, np.ndarray): + # We assume that all instances of the np.array info are of the same shape + array = np.zeros( + (self.num_envs, *value.shape), dtype=value.dtype + ) + else: + # For unknown objects, we use a Numpy object array + array = np.full(self.num_envs, fill_value=None, dtype=object) + # Otherwise, just use the array that already exists + else: + array = vector_infos[key] - info_array[env_num], array_mask[env_num] = info[k], True - infos[k], infos[f"_{k}"] = info_array, array_mask - return infos + # Assign the data in the `env_num` position + # We only want to run this for the base-case data (not recursive data forcing the ugly function structure) + array[env_num] = value - def _init_info_arrays(self, dtype: type) -> tuple[np.ndarray, np.ndarray]: - """Initialize the info array. + # Get the array mask and if it doesn't already exist then create a zero bool array + array_mask = vector_infos.get( + f"_{key}", np.zeros(self.num_envs, dtype=np.bool_) + ) + array_mask[env_num] = True - Initialize the info array. If the dtype is numeric - the info array will have the same dtype, otherwise - will be an array of `None`. Also, a boolean array - of the same length is returned. It will be used for - assessing which environment has info data. + # Update the vector info with the updated data and mask information + vector_infos[key], vector_infos[f"_{key}"] = array, array_mask - Args: - dtype (type): data type of the info coming from the env. - - Returns: - array (np.ndarray): the initialized info array. - array_mask (np.ndarray): the initialized boolean array. - - """ - if dtype in [int, float, bool] or issubclass(dtype, np.number): - array = np.zeros(self.num_envs, dtype=dtype) - else: - array = np.zeros(self.num_envs, dtype=object) - array[:] = None - array_mask = np.zeros(self.num_envs, dtype=bool) - return array, array_mask + return vector_infos def __del__(self): """Closes the vector environment.""" @@ -441,23 +444,23 @@ class VectorObservationWrapper(VectorWrapper): options: dict[str, Any] | None = None, ) -> tuple[ObsType, dict[str, Any]]: """Modifies the observation returned from the environment ``reset`` using the :meth:`observation`.""" - obs, info = self.env.reset(seed=seed, options=options) - return self.vector_observation(obs), info + observations, infos = self.env.reset(seed=seed, options=options) + return self.observation(observations), infos def step( self, actions: ActType ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]: """Modifies the observation returned from the environment ``step`` using the :meth:`observation`.""" - observation, reward, termination, truncation, info = self.env.step(actions) + observations, rewards, terminations, truncations, infos = self.env.step(actions) return ( - self.vector_observation(observation), - reward, - termination, - truncation, - self.update_final_obs(info), + self.observation(observations), + rewards, + terminations, + truncations, + infos, ) - def vector_observation(self, observation: ObsType) -> ObsType: + def observation(self, observation: ObsType) -> ObsType: """Defines the vector observation transformation. Args: @@ -468,25 +471,6 @@ class VectorObservationWrapper(VectorWrapper): """ raise NotImplementedError - def single_observation(self, observation: ObsType) -> ObsType: - """Defines the single observation transformation. - - Args: - observation: A single observation from the environment - - Returns: - The transformed observation - """ - raise NotImplementedError - - def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]: - """Updates the `final_obs` in the info using `single_observation`.""" - if "final_observation" in info: - for i, obs in enumerate(info["final_observation"]): - if obs is not None: - info["final_observation"][i] = self.single_observation(obs) - return info - class VectorActionWrapper(VectorWrapper): """Wraps the vectorized environment to allow a modular transformation of the actions. @@ -522,14 +506,14 @@ class VectorRewardWrapper(VectorWrapper): self, actions: ActType ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]: """Steps through the environment returning a reward modified by :meth:`reward`.""" - observation, reward, termination, truncation, info = self.env.step(actions) - return observation, self.rewards(reward), termination, truncation, info + observations, rewards, terminations, truncations, infos = self.env.step(actions) + return observations, self.rewards(rewards), terminations, truncations, infos - def rewards(self, reward: ArrayType) -> ArrayType: + def rewards(self, rewards: ArrayType) -> ArrayType: """Transform the reward before returning it. Args: - reward (array): the reward to transform + rewards (array): the reward to transform Returns: array: the transformed reward diff --git a/gymnasium/wrappers/common.py b/gymnasium/wrappers/common.py index 900131531..92a974fbe 100644 --- a/gymnasium/wrappers/common.py +++ b/gymnasium/wrappers/common.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING, Any, SupportsFloat import gymnasium as gym from gymnasium import logger -from gymnasium.core import ActType, ObsType, RenderFrame +from gymnasium.core import ActType, ObsType, RenderFrame, WrapperObsType from gymnasium.error import ResetNeeded from gymnasium.utils.passive_env_checker import ( check_action_space, @@ -196,6 +196,15 @@ class Autoreset( gym.utils.RecordConstructorArgs.__init__(self) gym.Wrapper.__init__(self, env) + self.autoreset = False + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Resets the environment and sets autoreset to False preventing.""" + self.autoreset = False + return super().reset(seed=seed, options=options) + def step( self, action: ActType ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: @@ -207,24 +216,13 @@ class Autoreset( Returns: The autoreset environment :meth:`step` """ - obs, reward, terminated, truncated, info = self.env.step(action) - - if terminated or truncated: - new_obs, new_info = self.env.reset() - - assert ( - "final_observation" not in new_info - ), f'new info dict already contains "final_observation", info keys: {new_info.keys()}' - assert ( - "final_info" not in new_info - ), f'new info dict already contains "final_observation", info keys: {new_info.keys()}' - - new_info["final_observation"] = obs - new_info["final_info"] = info - - obs = new_obs - info = new_info + if self.autoreset: + obs, info = self.env.reset() + reward, terminated, truncated = 0.0, False, False + else: + obs, reward, terminated, truncated, info = self.env.step(action) + self.autoreset = terminated or truncated return obs, reward, terminated, truncated, info @@ -470,14 +468,14 @@ class RecordEpisodeStatistics( def __init__( self, env: gym.Env[ObsType, ActType], - buffer_length: int | None = 100, + buffer_length: int = 100, stats_key: str = "episode", ): """This wrapper will keep track of cumulative rewards and episode lengths. Args: env (Env): The environment to apply the wrapper - buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue` + buffer_length: The size of the buffers :attr:`return_queue`, :attr:`length_queue` and :attr:`time_queue` stats_key: The info key for the episode statistics """ gym.utils.RecordConstructorArgs.__init__(self) @@ -520,6 +518,7 @@ class RecordEpisodeStatistics( self.length_queue.append(self.episode_lengths) self.episode_count += 1 + self.episode_start_time = time.perf_counter() return obs, reward, terminated, truncated, info diff --git a/gymnasium/wrappers/stateful_observation.py b/gymnasium/wrappers/stateful_observation.py index 5817d1a5a..6f211d72f 100644 --- a/gymnasium/wrappers/stateful_observation.py +++ b/gymnasium/wrappers/stateful_observation.py @@ -444,7 +444,7 @@ class NormalizeObservation( Change logs: * v0.21.0 - Initially add - * v1.0.0 - Add `update_running_mean` attribute to allow disabling of updating the running mean / standard + * v1.0.0 - Add `update_running_mean` attribute to allow disabling of updating the running mean / standard, particularly useful for evaluation time. """ def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8): diff --git a/gymnasium/wrappers/vector/common.py b/gymnasium/wrappers/vector/common.py index 0ad5883d6..2e89fd5e1 100644 --- a/gymnasium/wrappers/vector/common.py +++ b/gymnasium/wrappers/vector/common.py @@ -62,14 +62,21 @@ class RecordEpisodeStatistics(VectorWrapper): None, None], dtype=object)} """ - def __init__(self, env: VectorEnv, deque_size: int = 100): + def __init__( + self, + env: VectorEnv, + deque_size: int = 100, + stats_key: str = "episode", + ): """This wrapper will keep track of cumulative rewards and episode lengths. Args: env (Env): The environment to apply the wrapper deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue` + stats_key: The info key to save the data """ super().__init__(env) + self._stats_key = stats_key self.episode_count = 0 @@ -77,6 +84,7 @@ class RecordEpisodeStatistics(VectorWrapper): self.episode_returns: np.ndarray = np.zeros(()) self.episode_lengths: np.ndarray = np.zeros(()) + self.time_queue = deque(maxlen=deque_size) self.return_queue = deque(maxlen=deque_size) self.length_queue = deque(maxlen=deque_size) @@ -88,11 +96,9 @@ class RecordEpisodeStatistics(VectorWrapper): """Resets the environment using kwargs and resets the episode returns and lengths.""" obs, info = super().reset(seed=seed, options=options) - self.episode_start_times = np.full( - self.num_envs, time.perf_counter(), dtype=np.float32 - ) - self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) - self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) + self.episode_start_times = np.full(self.num_envs, time.perf_counter()) + self.episode_returns = np.zeros(self.num_envs) + self.episode_lengths = np.zeros(self.num_envs) return obs, info @@ -110,7 +116,7 @@ class RecordEpisodeStatistics(VectorWrapper): assert isinstance( infos, dict - ), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order." + ), f"`vector.RecordEpisodeStatistics` requires `info` type to be `dict`, its actual type is {type(infos)}. This may be due to usage of other wrappers in the wrong order." self.episode_returns += rewards self.episode_lengths += 1 @@ -119,25 +125,25 @@ class RecordEpisodeStatistics(VectorWrapper): num_dones = np.sum(dones) if num_dones: - if "episode" in infos or "_episode" in infos: + if self._stats_key in infos or f"_{self._stats_key}" in infos: raise ValueError( - "Attempted to add episode stats when they already exist" + f"Attempted to add episode stats when they already exist, info keys: {list(infos.keys())}" ) else: - infos["episode"] = { + episode_time_length = np.round( + time.perf_counter() - self.episode_start_times, 6 + ) + infos[self._stats_key] = { "r": np.where(dones, self.episode_returns, 0.0), "l": np.where(dones, self.episode_lengths, 0), - "t": np.where( - dones, - np.round(time.perf_counter() - self.episode_start_times, 6), - 0.0, - ), + "t": np.where(dones, episode_time_length, 0.0), } - infos["_episode"] = dones + infos[f"_{self._stats_key}"] = dones self.episode_count += num_dones for i in np.where(dones): + self.time_queue.extend(episode_time_length[i]) self.return_queue.extend(self.episode_returns[i]) self.length_queue.extend(self.episode_lengths[i]) diff --git a/gymnasium/wrappers/vector/dict_info_to_list.py b/gymnasium/wrappers/vector/dict_info_to_list.py index 64b908ca1..3bcc4ab22 100644 --- a/gymnasium/wrappers/vector/dict_info_to_list.py +++ b/gymnasium/wrappers/vector/dict_info_to_list.py @@ -3,6 +3,8 @@ from __future__ import annotations from typing import Any +import numpy as np + from gymnasium.core import ActType, ObsType from gymnasium.vector.vector_env import ArrayType, VectorEnv, VectorWrapper @@ -78,6 +80,7 @@ class DictInfoToList(VectorWrapper): ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]: """Steps through the environment, convert dict info to list.""" observation, reward, terminated, truncated, infos = self.env.step(actions) + assert isinstance(infos, dict) list_info = self._convert_info_to_list(infos) return observation, reward, terminated, truncated, list_info @@ -90,11 +93,12 @@ class DictInfoToList(VectorWrapper): ) -> tuple[ObsType, list[dict[str, Any]]]: """Resets the environment using kwargs.""" obs, infos = self.env.reset(seed=seed, options=options) + assert isinstance(infos, dict) list_info = self._convert_info_to_list(infos) return obs, list_info - def _convert_info_to_list(self, infos: dict) -> list[dict[str, Any]]: + def _convert_info_to_list(self, vector_infos: dict) -> list[dict[str, Any]]: """Convert the dict info to list. Convert the dict info of the vectorized environment @@ -102,52 +106,28 @@ class DictInfoToList(VectorWrapper): has the info of the i-th environment. Args: - infos (dict): info dict coming from the env. + vector_infos (dict): info dict coming from the env. Returns: list_info (list): converted info. - """ list_info = [{} for _ in range(self.num_envs)] - list_info = self._process_episode_statistics(infos, list_info) - for k in infos: - if k.startswith("_"): + + for key, value in vector_infos.items(): + if key.startswith("_"): continue - for i, has_info in enumerate(infos[f"_{k}"]): - if has_info: - list_info[i][k] = infos[k][i] - return list_info - # todo - I think this function should be more general for any information - def _process_episode_statistics(self, infos: dict, list_info: list) -> list[dict]: - """Process episode statistics. - - `RecordEpisodeStatistics` wrapper add extra - information to the info. This information are in - the form of a dict of dict. This method process these - information and add them to the info. - `RecordEpisodeStatistics` info contains the keys - "r", "l", "t" which represents "cumulative reward", - "episode length", "elapsed time since instantiation of wrapper". - - Args: - infos (dict): infos coming from `RecordEpisodeStatistics`. - list_info (list): info of the current vectorized environment. - - Returns: - list_info (list): updated info. - - """ - episode_statistics = infos.pop("episode", False) - if not episode_statistics: - return list_info - - episode_statistics_mask = infos.pop("_episode") - for i, has_info in enumerate(episode_statistics_mask): - if has_info: - list_info[i]["episode"] = {} - list_info[i]["episode"]["r"] = episode_statistics["r"][i] - list_info[i]["episode"]["l"] = episode_statistics["l"][i] - list_info[i]["episode"]["t"] = episode_statistics["t"][i] + if isinstance(value, dict): + value_list_info = self._convert_info_to_list(value) + for env_num, (env_info, has_info) in enumerate( + zip(value_list_info, vector_infos[f"_{key}"]) + ): + if has_info: + list_info[env_num][key] = env_info + else: + assert isinstance(value, np.ndarray) + for env_num, has_info in enumerate(vector_infos[f"_{key}"]): + if has_info: + list_info[env_num][key] = value[env_num] return list_info diff --git a/gymnasium/wrappers/vector/stateful_observation.py b/gymnasium/wrappers/vector/stateful_observation.py index 1bfbf18c2..8b2808cd4 100644 --- a/gymnasium/wrappers/vector/stateful_observation.py +++ b/gymnasium/wrappers/vector/stateful_observation.py @@ -34,9 +34,9 @@ class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructor >>> for _ in range(100): ... obs, *_ = envs.step(envs.action_space.sample()) >>> np.mean(obs) - -0.017698428 + 0.024251968 >>> np.std(obs) - 0.62041104 + 0.62259156 >>> envs.close() Example with the normalize reward wrapper: @@ -48,9 +48,9 @@ class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructor >>> for _ in range(100): ... obs, *_ = envs.step(envs.action_space.sample()) >>> np.mean(obs) - -0.28381696 + -0.2359734 >>> np.std(obs) - 1.21742 + 1.1938739 >>> envs.close() """ @@ -81,29 +81,15 @@ class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructor """Sets the property to freeze/continue the running mean calculation of the observation statistics.""" self._update_running_mean = setting - def vector_observation(self, observation: ObsType) -> ObsType: + def observation(self, observations: ObsType) -> ObsType: """Defines the vector observation normalization function. Args: - observation: A vector observation from the environment + observations: A vector observation from the environment Returns: the normalized observation """ - return self._normalize_observations(observation) - - def single_observation(self, observation: ObsType) -> ObsType: - """Defines the single observation normalization function. - - Args: - observation: A single observation from the environment - - Returns: - The normalized observation - """ - return self._normalize_observations(observation[None]) - - def _normalize_observations(self, observations: ObsType) -> ObsType: if self._update_running_mean: self.obs_rms.update(observations) return (observations - self.obs_rms.mean) / np.sqrt( diff --git a/gymnasium/wrappers/vector/vectorize_observation.py b/gymnasium/wrappers/vector/vectorize_observation.py index 94cacbbe3..c0b8f6331 100644 --- a/gymnasium/wrappers/vector/vectorize_observation.py +++ b/gymnasium/wrappers/vector/vectorize_observation.py @@ -37,13 +37,10 @@ class TransformObservation(VectorObservationWrapper): >>> def scale_and_shift(obs): ... return (obs - 1.0) * 2.0 ... - >>> def vector_scale_and_shift(obs): - ... return (obs - 1.0) * 2.0 - ... >>> import gymnasium as gym >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync") >>> new_obs_space = Box(low=envs.observation_space.low, high=envs.observation_space.high) - >>> envs = TransformObservation(envs, single_func=scale_and_shift, vector_func=vector_scale_and_shift) + >>> envs = TransformObservation(envs, func=scale_and_shift, observation_space=new_obs_space) >>> obs, info = envs.reset(seed=123) >>> obs array([[-1.9635296, -2.0892358, -2.055928 , -2.0631256], @@ -55,16 +52,14 @@ class TransformObservation(VectorObservationWrapper): def __init__( self, env: VectorEnv, - vector_func: Callable[[ObsType], Any], - single_func: Callable[[ObsType], Any], + func: Callable[[ObsType], Any], observation_space: Space | None = None, ): """Constructor for the transform observation wrapper. Args: env: The vector environment to wrap - vector_func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``. - single_func: A function that will transform an individual observation, this function will be used for the final observation from the environment and is returned under ``info`` and not the normal observation. + func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``. observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``. """ super().__init__(env) @@ -72,16 +67,11 @@ class TransformObservation(VectorObservationWrapper): if observation_space is not None: self.observation_space = observation_space - self.vector_func = vector_func - self.single_func = single_func + self.func = func - def vector_observation(self, observation: ObsType) -> ObsType: + def observation(self, observations: ObsType) -> ObsType: """Apply function to the vector observation.""" - return self.vector_func(observation) - - def single_observation(self, observation: ObsType) -> ObsType: - """Apply function to the single observation.""" - return self.single_func(observation) + return self.func(observations) class VectorizeTransformObservation(VectorObservationWrapper): @@ -158,16 +148,16 @@ class VectorizeTransformObservation(VectorObservationWrapper): self.same_out = self.observation_space == self.env.observation_space self.out = create_empty_array(self.single_observation_space, self.num_envs) - def vector_observation(self, observation: ObsType) -> ObsType: + def observation(self, observations: ObsType) -> ObsType: """Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again.""" if self.same_out: return concatenate( self.single_observation_space, tuple( self.wrapper.func(obs) - for obs in iterate(self.observation_space, observation) + for obs in iterate(self.observation_space, observations) ), - observation, + observations, ) else: return deepcopy( @@ -175,16 +165,12 @@ class VectorizeTransformObservation(VectorObservationWrapper): self.single_observation_space, tuple( self.wrapper.func(obs) - for obs in iterate(self.env.observation_space, observation) + for obs in iterate(self.env.observation_space, observations) ), self.out, ) ) - def single_observation(self, observation: ObsType) -> ObsType: - """Transforms a single observation using the wrapper transformation function.""" - return self.wrapper.func(observation) - class FilterObservation(VectorizeTransformObservation): """Vector wrapper for filtering dict or tuple observation spaces. diff --git a/tests/envs/functional/test_jax.py b/tests/envs/functional/test_jax.py index 170ba9ed2..01b8ad198 100644 --- a/tests/envs/functional/test_jax.py +++ b/tests/envs/functional/test_jax.py @@ -6,8 +6,14 @@ import jax.numpy as jnp # noqa: E402 import jax.random as jrng # noqa: E402 import numpy as np # noqa: E402 -from gymnasium.envs.phys2d.cartpole import CartPoleFunctional # noqa: E402 -from gymnasium.envs.phys2d.pendulum import PendulumFunctional # noqa: E402 +from gymnasium.envs.phys2d.cartpole import ( # noqa: E402 + CartPoleFunctional, + CartPoleJaxVectorEnv, +) +from gymnasium.envs.phys2d.pendulum import ( # noqa: E402 + PendulumFunctional, + PendulumJaxVectorEnv, +) @pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional]) @@ -105,3 +111,34 @@ def test_vmap(env_class): assert obs.dtype == jnp.float32 state = next_state + + +@pytest.mark.parametrize("env_class", [CartPoleJaxVectorEnv, PendulumJaxVectorEnv]) +def test_vectorized(env_class): + env = env_class(num_envs=10) + env.action_space.seed(0) + + obs, info = env.reset(seed=0) + assert obs.shape == (10,) + env.single_observation_space.shape + assert isinstance(obs, np.ndarray) + assert isinstance(info, dict) + + for t in range(100): + action = env.action_space.sample() + obs, reward, terminated, truncated, info = env.step(action) + + assert obs.shape == (10,) + env.single_observation_space.shape + assert isinstance(obs, np.ndarray) + assert reward.shape == (10,) + assert isinstance(reward, np.ndarray) + assert terminated.shape == (10,) + assert isinstance(terminated, np.ndarray) + assert truncated.shape == (10,) + assert isinstance(truncated, np.ndarray) + assert isinstance(info, dict) + + # These were removed in the new autoreset order + assert "final_observation" not in info + assert "final_info" not in info + assert "_final_observation" not in info + assert "_final_info" not in info diff --git a/tests/vector/test_async_vector_env.py b/tests/vector/test_async_vector_env.py index 6f2609d8d..478b5b5d3 100644 --- a/tests/vector/test_async_vector_env.py +++ b/tests/vector/test_async_vector_env.py @@ -33,7 +33,7 @@ def test_create_async_vector_env(shared_memory): @pytest.mark.parametrize("shared_memory", [True, False]) def test_reset_async_vector_env(shared_memory): - """Test the reset of an sync vector environment with or without shared memory.""" + """Test the reset of async vector environment with or without shared memory.""" env_fns = [make_env("CartPole-v1", i) for i in range(8)] env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) diff --git a/tests/vector/test_vector_env.py b/tests/vector/test_vector_env.py index 211c36882..dba5cdd51 100644 --- a/tests/vector/test_vector_env.py +++ b/tests/vector/test_vector_env.py @@ -6,6 +6,7 @@ import numpy as np import pytest from gymnasium.spaces import Discrete +from gymnasium.utils.env_checker import data_equivalence from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv from tests.testing_env import GenericTestEnv from tests.vector.testing_utils import make_env @@ -29,6 +30,7 @@ def test_vector_env_equal(shared_memory): async_observations, async_infos = async_env.reset(seed=0) sync_observations, sync_infos = sync_env.reset(seed=0) assert np.all(async_observations == sync_observations) + assert data_equivalence(async_infos, sync_infos) for _ in range(num_steps): actions = async_env.action_space.sample() @@ -49,16 +51,11 @@ def test_vector_env_equal(shared_memory): sync_infos, ) = sync_env.step(actions) - if any(sync_terminations) or any(sync_truncations): - assert "final_observation" in async_infos - assert "_final_observation" in async_infos - assert "final_observation" in sync_infos - assert "_final_observation" in sync_infos - assert np.all(async_observations == sync_observations) assert np.all(async_rewards == sync_rewards) assert np.all(async_terminations == sync_terminations) assert np.all(async_truncations == sync_truncations) + assert data_equivalence(async_infos, sync_infos) async_env.close() sync_env.close() @@ -115,14 +112,13 @@ def test_final_obs_info(vectoriser): ) obs, _, termination, _, info = env.step([3]) + assert obs == np.array([0]) and info == {"action": 3, "_action": np.array([True])} + + obs, _, terminated, _, info = env.step([4]) assert ( obs == np.array([0]) and termination == np.array([True]) and info["reset"] == np.array([True]) ) - assert "final_observation" in info and "final_info" in info - assert info["final_observation"] == np.array([0]) and info["final_info"] == { - "action": 3 - } env.close() diff --git a/tests/vector/test_vector_env_info.py b/tests/vector/test_vector_env_info.py index 6711aa12d..09bed45c2 100644 --- a/tests/vector/test_vector_env_info.py +++ b/tests/vector/test_vector_env_info.py @@ -1,66 +1,164 @@ """Test the vector environment information.""" +from __future__ import annotations + +from typing import Any, SupportsFloat + import numpy as np import pytest import gymnasium as gym -from gymnasium.vector.sync_vector_env import SyncVectorEnv -from tests.vector.testing_utils import make_env +from gymnasium.core import ActType, ObsType +from gymnasium.spaces import Box, Discrete +from gymnasium.utils.env_checker import data_equivalence +from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv, VectorEnv -ENV_ID = "CartPole-v1" -NUM_ENVS = 3 -ENV_STEPS = 50 -SEED = 42 +def test_vector_add_info(): + env = VectorEnv() + + # Test num-envs==1 then expand_dims(sub-env-info) == vector-infos + env.num_envs = 1 + sub_env_info = {"a": 0, "b": 0.0, "c": None, "d": np.zeros((2,)), "e": Discrete(1)} + vector_infos = env._add_info({}, sub_env_info, 0) + expected_vector_infos = { + "a": np.array([0]), + "b": np.array([0.0]), + "c": np.array([None], dtype=object), + "d": np.zeros( + ( + 1, + 2, + ) + ), + "e": np.array([Discrete(1)], dtype=object), + "_a": np.array([True]), + "_b": np.array([True]), + "_c": np.array([True]), + "_d": np.array([True]), + "_e": np.array([True]), + } + assert data_equivalence(vector_infos, expected_vector_infos) + + # Thought: num-envs>1 then vector-infos should have the same structure as sub-env-info + env.num_envs = 3 + sub_env_infos = [ + {"a": 0, "b": 0.0, "c": None, "d": np.zeros((2,)), "e": Discrete(1)}, + {"a": 1, "b": 1.0, "c": None, "d": np.zeros((2,)), "e": Discrete(2)}, + {"a": 2, "b": 2.0, "c": None, "d": np.zeros((2,)), "e": Discrete(3)}, + ] + + vector_infos = {} + for i, info in enumerate(sub_env_infos): + vector_infos = env._add_info(vector_infos, info, i) + + expected_vector_infos = { + "a": np.array([0, 1, 2]), + "b": np.array([0.0, 1.0, 2.0]), + "c": np.array([None, None, None], dtype=object), + "d": np.zeros((3, 2)), + "e": np.array([Discrete(1), Discrete(2), Discrete(3)], dtype=object), + "_a": np.array([True, True, True]), + "_b": np.array([True, True, True]), + "_c": np.array([True, True, True]), + "_d": np.array([True, True, True]), + "_e": np.array([True, True, True]), + } + assert data_equivalence(vector_infos, expected_vector_infos) + + # Test different structures of sub-infos + env.num_envs = 3 + sub_env_infos = [ + {"a": 1, "b": 1.0}, + {"c": None, "d": np.zeros((2,))}, + {"e": Discrete(3)}, + ] + + vector_infos = {} + for i, info in enumerate(sub_env_infos): + vector_infos = env._add_info(vector_infos, info, i) + + expected_vector_infos = { + "a": np.array([1, 0, 0]), + "b": np.array([1.0, 0.0, 0.0]), + "c": np.array([None, None, None], dtype=object), + "d": np.zeros((3, 2)), + "e": np.array([None, None, Discrete(3)], dtype=object), + "_a": np.array([True, False, False]), + "_b": np.array([True, False, False]), + "_c": np.array([False, True, False]), + "_d": np.array([False, True, False]), + "_e": np.array([False, False, True]), + } + assert data_equivalence(vector_infos, expected_vector_infos) + + # Test recursive structure + env.num_envs = 3 + sub_env_infos = [ + {"episode": {"a": 1, "b": 1.0}}, + {"episode": {"a": 2, "b": 2.0}, "a": 1}, + {"a": 2}, + ] + + vector_infos = {} + for i, info in enumerate(sub_env_infos): + vector_infos = env._add_info(vector_infos, info, i) + + expected_vector_infos = { + "episode": { + "a": np.array([1, 2, 0]), + "b": np.array([1.0, 2.0, 0.0]), + "_a": np.array([True, True, False]), + "_b": np.array([True, True, False]), + }, + "_episode": np.array([True, True, False]), + "a": np.array([0, 1, 2]), + "_a": np.array([False, True, True]), + } + assert data_equivalence(vector_infos, expected_vector_infos) -@pytest.mark.parametrize("vectorization_mode", ["async", "sync"]) -def test_vector_env_info(vectorization_mode: str): - """Test vector environment info for different vectorization modes.""" - env = gym.make_vec( - ENV_ID, - num_envs=NUM_ENVS, - vectorization_mode=vectorization_mode, +class ReturnInfoEnv(gym.Env): + def __init__(self, infos): + self.observation_space = Box(0, 1) + self.action_space = Box(0, 1) + + self.infos = infos + + def reset( + self, + *, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[ObsType, dict[str, Any]]: + return self.observation_space.sample(), self.infos[0] + + def step( + self, action: ActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + return self.observation_space.sample(), 0, True, False, self.infos[1] + + +@pytest.mark.parametrize("vectorizer", [AsyncVectorEnv, SyncVectorEnv]) +def test_vectorizers(vectorizer): + vec_env = vectorizer( + [ + lambda: ReturnInfoEnv([{"a": 1}, {"c": np.array([1, 2])}]), + lambda: ReturnInfoEnv([{"a": 2, "b": 3}, {"c": np.array([3, 4])}]), + ] ) - env.reset(seed=SEED) - for _ in range(ENV_STEPS): - env.action_space.seed(SEED) - action = env.action_space.sample() - _, _, terminateds, truncateds, infos = env.step(action) - if any(terminateds) or any(truncateds): - assert len(infos["final_observation"]) == NUM_ENVS - assert len(infos["_final_observation"]) == NUM_ENVS - assert isinstance(infos["final_observation"], np.ndarray) - assert isinstance(infos["_final_observation"], np.ndarray) + reset_expected_infos = { + "a": np.array([1, 2]), + "b": np.array([0, 3]), + "_a": np.array([True, True]), + "_b": np.array([False, True]), + } + step_expected_infos = { + "c": np.array([[1, 2], [3, 4]]), + "_c": np.array([True, True]), + } - for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)): - if terminated or truncated: - assert infos["_final_observation"][i] - else: - assert not infos["_final_observation"][i] - assert infos["final_observation"][i] is None - - env.close() - - -@pytest.mark.parametrize("concurrent_ends", [1, 2, 3]) -def test_vector_env_info_concurrent_termination(concurrent_ends): - """Test the vector environment information works with concurrent termination.""" - # envs that need to terminate together will have the same action - actions = [0] * concurrent_ends + [1] * (NUM_ENVS - concurrent_ends) - envs = [make_env(ENV_ID, SEED) for _ in range(NUM_ENVS)] - envs = SyncVectorEnv(envs) - - for _ in range(ENV_STEPS): - _, _, terminateds, truncateds, infos = envs.step(actions) - if any(terminateds) or any(truncateds): - for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)): - if i < concurrent_ends: - assert terminated or truncated - assert infos["_final_observation"][i] - else: - assert not infos["_final_observation"][i] - assert infos["final_observation"][i] is None - return - - envs.close() + _, reset_info = vec_env.reset() + assert data_equivalence(reset_info, reset_expected_infos) + _, _, _, _, step_info = vec_env.step(vec_env.action_space.sample()) + assert data_equivalence(step_info, step_expected_infos) diff --git a/tests/wrappers/test_autoreset.py b/tests/wrappers/test_autoreset.py index 1431b061f..1448264f0 100644 --- a/tests/wrappers/test_autoreset.py +++ b/tests/wrappers/test_autoreset.py @@ -45,19 +45,15 @@ def test_autoreset_wrapper_autoreset(): assert info == {"count": 2} obs, reward, terminated, truncated, info = env.step(action) - assert obs == np.array([0]) + assert obs == np.array([3]) assert (terminated or truncated) is True assert reward == 1 - assert info == { - "count": 0, - "final_observation": np.array([3]), - "final_info": {"count": 3}, - } + assert info == {"count": 3} obs, reward, terminated, truncated, info = env.step(action) - assert obs == np.array([1]) + assert obs == np.array([0]) assert reward == 0 assert (terminated or truncated) is False - assert info == {"count": 1} + assert info == {"count": 0} env.close() diff --git a/tests/wrappers/vector/test_dict_info_to_list.py b/tests/wrappers/vector/test_dict_info_to_list.py index cb09029a9..9e61f29b9 100644 --- a/tests/wrappers/vector/test_dict_info_to_list.py +++ b/tests/wrappers/vector/test_dict_info_to_list.py @@ -1,21 +1,22 @@ """Test suite for DictInfoTolist wrapper.""" +from __future__ import annotations + +from typing import Any import numpy as np import pytest import gymnasium as gym -from gymnasium.wrappers.vector import DictInfoToList, RecordEpisodeStatistics +from gymnasium.core import ObsType +from gymnasium.spaces import Discrete +from gymnasium.utils.env_checker import data_equivalence +from gymnasium.vector import VectorEnv +from gymnasium.wrappers.vector import DictInfoToList -ENV_ID = "CartPole-v1" -NUM_ENVS = 3 -ENV_STEPS = 50 -SEED = 42 - - -def test_usage_in_vector_env(): - env = gym.make(ENV_ID, disable_env_checker=True) - vector_env = gym.make_vec(ENV_ID, num_envs=NUM_ENVS, vectorization_mode="sync") +def test_usage_in_vector_env(env_id: str = "CartPole-v1", num_envs: int = 3): + env = gym.make(env_id, disable_env_checker=True) + vector_env = gym.make_vec(env_id, num_envs=num_envs) DictInfoToList(vector_env) @@ -23,40 +24,140 @@ def test_usage_in_vector_env(): DictInfoToList(env) -def test_info_to_list(): - env_to_wrap = gym.make_vec(ENV_ID, num_envs=NUM_ENVS, vectorization_mode="sync") - wrapped_env = DictInfoToList(env_to_wrap) - wrapped_env.action_space.seed(SEED) - _, info = wrapped_env.reset(seed=SEED) - assert isinstance(info, list) - assert len(info) == NUM_ENVS +class ResetOptionAsInfo(VectorEnv): + """Minimal implementation to test the conversion of vector dict info to list info.""" - for _ in range(ENV_STEPS): - action = wrapped_env.action_space.sample() - _, _, terminateds, truncateds, list_info = wrapped_env.step(action) - for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)): - if terminated or truncated: - assert "final_observation" in list_info[i] - else: - assert "final_observation" not in list_info[i] + def reset( + self, + *, + seed: int | list[int] | None = None, + options: dict[str, Any] | None = None, # options are passed are the info output + ) -> tuple[ObsType, dict[str, Any]]: + return None, options -def test_info_to_list_statistics(): - env_to_wrap = gym.make_vec(ENV_ID, num_envs=NUM_ENVS, vectorization_mode="sync") - wrapped_env = DictInfoToList(RecordEpisodeStatistics(env_to_wrap)) - _, info = wrapped_env.reset(seed=SEED) - wrapped_env.action_space.seed(SEED) - assert isinstance(info, list) - assert len(info) == NUM_ENVS +def test_update_info(): + env = DictInfoToList(ResetOptionAsInfo()) - for _ in range(ENV_STEPS): - action = wrapped_env.action_space.sample() - _, _, terminateds, truncateds, list_info = wrapped_env.step(action) - for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)): - if terminated or truncated: - assert "episode" in list_info[i] - for stats in ["r", "l", "t"]: - assert stats in list_info[i]["episode"] - assert np.isscalar(list_info[i]["episode"][stats]) - else: - assert "episode" not in list_info[i] + # Test num-envs==1 then expand_dims(sub-env-info) == vector-infos + env.unwrapped.num_envs = 1 + + vector_infos = { + "a": np.array([0]), + "b": np.array([0.0]), + "c": np.array([None], dtype=object), + "d": np.zeros( + ( + 1, + 2, + ) + ), + "e": np.array([Discrete(1)], dtype=object), + "_a": np.array([True]), + "_b": np.array([True]), + "_c": np.array([True]), + "_d": np.array([True]), + "_e": np.array([True]), + } + _, list_info = env.reset(options=vector_infos) + expected_list_info = [ + { + "a": np.int64(0), + "b": np.float64(0.0), + "c": None, + "d": np.zeros((2,)), + "e": Discrete(1), + } + ] + + assert data_equivalence(list_info, expected_list_info) + + # Thought: num-envs>1 then vector-infos should have the same structure as sub-env-info + env.unwrapped.num_envs = 3 + + vector_infos = { + "a": np.array([0, 1, 2]), + "b": np.array([0.0, 1.0, 2.0]), + "c": np.array([None, None, None], dtype=object), + "d": np.zeros((3, 2)), + "e": np.array([Discrete(1), Discrete(2), Discrete(3)], dtype=object), + "_a": np.array([True, True, True]), + "_b": np.array([True, True, True]), + "_c": np.array([True, True, True]), + "_d": np.array([True, True, True]), + "_e": np.array([True, True, True]), + } + _, list_info = env.reset(options=vector_infos) + expected_list_info = [ + { + "a": np.int64(0), + "b": np.float64(0.0), + "c": None, + "d": np.zeros((2,)), + "e": Discrete(1), + }, + { + "a": np.int64(1), + "b": np.float64(1.0), + "c": None, + "d": np.zeros((2,)), + "e": Discrete(2), + }, + { + "a": np.int64(2), + "b": np.float64(2.0), + "c": None, + "d": np.zeros((2,)), + "e": Discrete(3), + }, + ] + + assert list_info[0].keys() == expected_list_info[0].keys() + for key in list_info[0].keys(): + assert data_equivalence(list_info[0][key], expected_list_info[0][key]) + assert data_equivalence(list_info, expected_list_info) + + # Test different structures of sub-infos + env.unwrapped.num_envs = 3 + + vector_infos = { + "a": np.array([1, 0, 0]), + "_a": np.array([True, False, False]), + "b": np.array([1.0, 0.0, 0.0]), + "_b": np.array([True, False, False]), + "c": np.array([None, None, None], dtype=object), + "_c": np.array([False, True, False]), + "_d": np.array([False, True, False]), + "d": np.zeros((3, 2)), + "e": np.array([None, None, Discrete(3)], dtype=object), + "_e": np.array([False, False, True]), + } + _, list_info = env.reset(options=vector_infos) + expected_list_info = [ + {"a": np.int64(1), "b": np.float64(1.0)}, + {"c": None, "d": np.zeros((2,))}, + {"e": Discrete(3)}, + ] + assert data_equivalence(list_info, expected_list_info) + + # Test recursive structure + env.unwrapped.num_envs = 3 + + vector_infos = { + "episode": { + "a": np.array([1, 2, 0]), + "b": np.array([1.0, 2.0, 0.0]), + "_a": np.array([True, True, False]), + "_b": np.array([True, True, False]), + }, + "_episode": np.array([True, True, False]), + "a": np.array([0, 1, 2]), + "_a": np.array([False, True, True]), + } + _, list_info = env.reset(options=vector_infos) + expected_list_info = [ + {"episode": {"a": np.int64(1), "b": np.float64(1.0)}}, + {"episode": {"a": np.int64(2), "b": np.float64(2.0)}, "a": np.int64(1)}, + {"a": np.int64(2)}, + ] + assert data_equivalence(list_info, expected_list_info) diff --git a/tests/wrappers/vector/test_record_episode_statistics.py b/tests/wrappers/vector/test_record_episode_statistics.py new file mode 100644 index 000000000..4de1b3fa6 --- /dev/null +++ b/tests/wrappers/vector/test_record_episode_statistics.py @@ -0,0 +1,71 @@ +import pytest + +import gymnasium as gym +from gymnasium.utils.env_checker import data_equivalence +from gymnasium.vector import VectorEnv + + +@pytest.mark.parametrize("num_envs", (1, 3)) +def test_record_episode_statistics(num_envs, env_id="CartPole-v1", num_steps=100): + wrapper_vector_env: VectorEnv = gym.wrappers.vector.RecordEpisodeStatistics( + gym.make_vec(id=env_id, num_envs=num_envs, vectorization_mode="sync"), + ) + vector_wrapper_env = gym.make_vec( + id=env_id, + num_envs=num_envs, + vectorization_mode="sync", + wrappers=(gym.wrappers.RecordEpisodeStatistics,), + ) + + assert wrapper_vector_env.action_space == vector_wrapper_env.action_space + assert wrapper_vector_env.observation_space == vector_wrapper_env.observation_space + assert ( + wrapper_vector_env.single_action_space == vector_wrapper_env.single_action_space + ) + assert ( + wrapper_vector_env.single_observation_space + == vector_wrapper_env.single_observation_space + ) + + assert wrapper_vector_env.num_envs == vector_wrapper_env.num_envs + + wrapper_vector_obs, wrapper_vector_info = wrapper_vector_env.reset(seed=123) + vector_wrapper_obs, vector_wrapper_info = vector_wrapper_env.reset(seed=123) + + assert data_equivalence(wrapper_vector_obs, vector_wrapper_obs) + assert data_equivalence(wrapper_vector_info, vector_wrapper_info) + + for _ in range(num_steps): + action = wrapper_vector_env.action_space.sample() + ( + wrapper_vector_obs, + wrapper_vector_reward, + wrapper_vector_terminated, + wrapper_vector_truncated, + wrapper_vector_info, + ) = wrapper_vector_env.step(action) + ( + vector_wrapper_obs, + vector_wrapper_reward, + vector_wrapper_terminated, + vector_wrapper_truncated, + vector_wrapper_info, + ) = vector_wrapper_env.step(action) + + data_equivalence(wrapper_vector_obs, vector_wrapper_obs) + data_equivalence(wrapper_vector_reward, vector_wrapper_reward) + data_equivalence(wrapper_vector_terminated, vector_wrapper_terminated) + data_equivalence(wrapper_vector_truncated, vector_wrapper_truncated) + + if "episode" in wrapper_vector_info: + assert "episode" in vector_wrapper_info + + wrapper_vector_time = wrapper_vector_info["episode"].pop("t") + vector_wrapper_time = vector_wrapper_info["episode"].pop("t") + assert wrapper_vector_time.shape == vector_wrapper_time.shape + assert wrapper_vector_time.dtype == vector_wrapper_time.dtype + + data_equivalence(wrapper_vector_info, vector_wrapper_info) + + wrapper_vector_env.close() + vector_wrapper_env.close() diff --git a/tests/wrappers/vector/test_vector_wrappers.py b/tests/wrappers/vector/test_vector_wrappers.py index ed6e85e97..cf0db8b69 100644 --- a/tests/wrappers/vector/test_vector_wrappers.py +++ b/tests/wrappers/vector/test_vector_wrappers.py @@ -42,21 +42,21 @@ def custom_environments(): ("CustomDictEnv-v0", "FilterObservation", {"filter_keys": ["a"]}), ("CartPole-v1", "FlattenObservation", {}), ("CarRacing-v2", "GrayscaleObservation", {}), - # ("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}), + ("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}), ("CarRacing-v2", "ReshapeObservation", {"shape": (96, 48, 6)}), ("CartPole-v1", "RescaleObservation", {"min_obs": 0, "max_obs": 1}), ("CartPole-v1", "DtypeObservation", {"dtype": np.int32}), - # ("CartPole-v1", "RenderObservation", {}), - # ("CartPole-v1", "TimeAwareObservation", {}), - # ("CartPole-v1", "FrameStackObservation", {}), - # ("CartPole-v1", "DelayObservation", {}), + # ("CartPole-v1", "RenderObservation", {}), # not implemented + # ("CartPole-v1", "TimeAwareObservation", {}), # not implemented + # ("CartPole-v1", "FrameStackObservation", {}), # not implemented + # ("CartPole-v1", "DelayObservation", {}), # not implemented ("MountainCarContinuous-v0", "ClipAction", {}), ( "MountainCarContinuous-v0", "RescaleAction", {"min_action": 1, "max_action": 2}, ), - ("CartPole-v1", "ClipReward", {"min_reward": 0.25, "max_reward": 0.75}), + ("CartPole-v1", "ClipReward", {"min_reward": -0.25, "max_reward": 0.75}), ), ) def test_vector_wrapper_equivalence(