mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-31 18:12:53 +00:00
Change autoreset order (#808)
Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
committed by
GitHub
parent
967bbf5823
commit
e9c66e4225
@@ -231,7 +231,7 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
|
||||
return self
|
||||
|
||||
def _add_info(
|
||||
self, infos: dict[str, Any], info: dict[str, Any], env_num: int
|
||||
self, vector_infos: dict[str, Any], env_info: dict[str, Any], env_num: int
|
||||
) -> dict[str, Any]:
|
||||
"""Add env info to the info dictionary of the vectorized environment.
|
||||
|
||||
@@ -241,48 +241,51 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
|
||||
whether or not the i-indexed environment has this `info`.
|
||||
|
||||
Args:
|
||||
infos (dict): the infos of the vectorized environment
|
||||
info (dict): the info coming from the single environment
|
||||
vector_infos (dict): the infos of the vectorized environment
|
||||
env_info (dict): the info coming from the single environment
|
||||
env_num (int): the index of the single environment
|
||||
|
||||
Returns:
|
||||
infos (dict): the (updated) infos of the vectorized environment
|
||||
|
||||
"""
|
||||
for k in info.keys():
|
||||
if k not in infos:
|
||||
info_array, array_mask = self._init_info_arrays(type(info[k]))
|
||||
for key, value in env_info.items():
|
||||
# If value is a dictionary, then we apply the `_add_info` recursively.
|
||||
if isinstance(value, dict):
|
||||
array = self._add_info(vector_infos.get(key, {}), value, env_num)
|
||||
# Otherwise, we are a base case to group the data
|
||||
else:
|
||||
info_array, array_mask = infos[k], infos[f"_{k}"]
|
||||
# If the key doesn't exist in the vector infos, then we can create an array of that batch type
|
||||
if key not in vector_infos:
|
||||
if type(value) in [int, float, bool] or issubclass(
|
||||
type(value), np.number
|
||||
):
|
||||
array = np.zeros(self.num_envs, dtype=type(value))
|
||||
elif isinstance(value, np.ndarray):
|
||||
# We assume that all instances of the np.array info are of the same shape
|
||||
array = np.zeros(
|
||||
(self.num_envs, *value.shape), dtype=value.dtype
|
||||
)
|
||||
else:
|
||||
# For unknown objects, we use a Numpy object array
|
||||
array = np.full(self.num_envs, fill_value=None, dtype=object)
|
||||
# Otherwise, just use the array that already exists
|
||||
else:
|
||||
array = vector_infos[key]
|
||||
|
||||
info_array[env_num], array_mask[env_num] = info[k], True
|
||||
infos[k], infos[f"_{k}"] = info_array, array_mask
|
||||
return infos
|
||||
# Assign the data in the `env_num` position
|
||||
# We only want to run this for the base-case data (not recursive data forcing the ugly function structure)
|
||||
array[env_num] = value
|
||||
|
||||
def _init_info_arrays(self, dtype: type) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Initialize the info array.
|
||||
# Get the array mask and if it doesn't already exist then create a zero bool array
|
||||
array_mask = vector_infos.get(
|
||||
f"_{key}", np.zeros(self.num_envs, dtype=np.bool_)
|
||||
)
|
||||
array_mask[env_num] = True
|
||||
|
||||
Initialize the info array. If the dtype is numeric
|
||||
the info array will have the same dtype, otherwise
|
||||
will be an array of `None`. Also, a boolean array
|
||||
of the same length is returned. It will be used for
|
||||
assessing which environment has info data.
|
||||
# Update the vector info with the updated data and mask information
|
||||
vector_infos[key], vector_infos[f"_{key}"] = array, array_mask
|
||||
|
||||
Args:
|
||||
dtype (type): data type of the info coming from the env.
|
||||
|
||||
Returns:
|
||||
array (np.ndarray): the initialized info array.
|
||||
array_mask (np.ndarray): the initialized boolean array.
|
||||
|
||||
"""
|
||||
if dtype in [int, float, bool] or issubclass(dtype, np.number):
|
||||
array = np.zeros(self.num_envs, dtype=dtype)
|
||||
else:
|
||||
array = np.zeros(self.num_envs, dtype=object)
|
||||
array[:] = None
|
||||
array_mask = np.zeros(self.num_envs, dtype=bool)
|
||||
return array, array_mask
|
||||
return vector_infos
|
||||
|
||||
def __del__(self):
|
||||
"""Closes the vector environment."""
|
||||
@@ -441,23 +444,23 @@ class VectorObservationWrapper(VectorWrapper):
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[ObsType, dict[str, Any]]:
|
||||
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
|
||||
obs, info = self.env.reset(seed=seed, options=options)
|
||||
return self.vector_observation(obs), info
|
||||
observations, infos = self.env.reset(seed=seed, options=options)
|
||||
return self.observation(observations), infos
|
||||
|
||||
def step(
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
|
||||
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
|
||||
observation, reward, termination, truncation, info = self.env.step(actions)
|
||||
observations, rewards, terminations, truncations, infos = self.env.step(actions)
|
||||
return (
|
||||
self.vector_observation(observation),
|
||||
reward,
|
||||
termination,
|
||||
truncation,
|
||||
self.update_final_obs(info),
|
||||
self.observation(observations),
|
||||
rewards,
|
||||
terminations,
|
||||
truncations,
|
||||
infos,
|
||||
)
|
||||
|
||||
def vector_observation(self, observation: ObsType) -> ObsType:
|
||||
def observation(self, observation: ObsType) -> ObsType:
|
||||
"""Defines the vector observation transformation.
|
||||
|
||||
Args:
|
||||
@@ -468,25 +471,6 @@ class VectorObservationWrapper(VectorWrapper):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def single_observation(self, observation: ObsType) -> ObsType:
|
||||
"""Defines the single observation transformation.
|
||||
|
||||
Args:
|
||||
observation: A single observation from the environment
|
||||
|
||||
Returns:
|
||||
The transformed observation
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Updates the `final_obs` in the info using `single_observation`."""
|
||||
if "final_observation" in info:
|
||||
for i, obs in enumerate(info["final_observation"]):
|
||||
if obs is not None:
|
||||
info["final_observation"][i] = self.single_observation(obs)
|
||||
return info
|
||||
|
||||
|
||||
class VectorActionWrapper(VectorWrapper):
|
||||
"""Wraps the vectorized environment to allow a modular transformation of the actions.
|
||||
@@ -522,14 +506,14 @@ class VectorRewardWrapper(VectorWrapper):
|
||||
self, actions: ActType
|
||||
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
|
||||
"""Steps through the environment returning a reward modified by :meth:`reward`."""
|
||||
observation, reward, termination, truncation, info = self.env.step(actions)
|
||||
return observation, self.rewards(reward), termination, truncation, info
|
||||
observations, rewards, terminations, truncations, infos = self.env.step(actions)
|
||||
return observations, self.rewards(rewards), terminations, truncations, infos
|
||||
|
||||
def rewards(self, reward: ArrayType) -> ArrayType:
|
||||
def rewards(self, rewards: ArrayType) -> ArrayType:
|
||||
"""Transform the reward before returning it.
|
||||
|
||||
Args:
|
||||
reward (array): the reward to transform
|
||||
rewards (array): the reward to transform
|
||||
|
||||
Returns:
|
||||
array: the transformed reward
|
||||
|
Reference in New Issue
Block a user