Change autoreset order (#808)

Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
Ariel Kwiatkowski
2023-12-03 19:50:18 +01:00
committed by GitHub
parent 967bbf5823
commit e9c66e4225
18 changed files with 591 additions and 364 deletions

View File

@@ -231,7 +231,7 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
return self
def _add_info(
self, infos: dict[str, Any], info: dict[str, Any], env_num: int
self, vector_infos: dict[str, Any], env_info: dict[str, Any], env_num: int
) -> dict[str, Any]:
"""Add env info to the info dictionary of the vectorized environment.
@@ -241,48 +241,51 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
whether or not the i-indexed environment has this `info`.
Args:
infos (dict): the infos of the vectorized environment
info (dict): the info coming from the single environment
vector_infos (dict): the infos of the vectorized environment
env_info (dict): the info coming from the single environment
env_num (int): the index of the single environment
Returns:
vector_infos (dict): the (updated) infos of the vectorized environment
"""
for k in info.keys():
if k not in infos:
info_array, array_mask = self._init_info_arrays(type(info[k]))
for key, value in env_info.items():
# If value is a dictionary, then we apply the `_add_info` recursively.
if isinstance(value, dict):
array = self._add_info(vector_infos.get(key, {}), value, env_num)
# Otherwise, we are a base case to group the data
else:
info_array, array_mask = infos[k], infos[f"_{k}"]
# If the key doesn't exist in the vector infos, then we can create an array of that batch type
if key not in vector_infos:
if type(value) in [int, float, bool] or issubclass(
type(value), np.number
):
array = np.zeros(self.num_envs, dtype=type(value))
elif isinstance(value, np.ndarray):
# We assume that all instances of the np.array info are of the same shape
array = np.zeros(
(self.num_envs, *value.shape), dtype=value.dtype
)
else:
# For unknown objects, we use a Numpy object array
array = np.full(self.num_envs, fill_value=None, dtype=object)
# Otherwise, just use the array that already exists
else:
array = vector_infos[key]
info_array[env_num], array_mask[env_num] = info[k], True
infos[k], infos[f"_{k}"] = info_array, array_mask
return infos
# Assign the data in the `env_num` position
# We only want to run this for the base-case data (not for nested dict data, which is handled by the recursive call); this is what forces the ugly function structure
array[env_num] = value
def _init_info_arrays(self, dtype: type) -> tuple[np.ndarray, np.ndarray]:
"""Initialize the info array.
# Get the array mask and if it doesn't already exist then create a zero bool array
array_mask = vector_infos.get(
f"_{key}", np.zeros(self.num_envs, dtype=np.bool_)
)
array_mask[env_num] = True
Initialize the info array. If the dtype is numeric
the info array will have the same dtype, otherwise
will be an array of `None`. Also, a boolean array
of the same length is returned. It will be used for
assessing which environment has info data.
# Update the vector info with the updated data and mask information
vector_infos[key], vector_infos[f"_{key}"] = array, array_mask
Args:
dtype (type): data type of the info coming from the env.
Returns:
array (np.ndarray): the initialized info array.
array_mask (np.ndarray): the initialized boolean array.
"""
if dtype in [int, float, bool] or issubclass(dtype, np.number):
array = np.zeros(self.num_envs, dtype=dtype)
else:
array = np.zeros(self.num_envs, dtype=object)
array[:] = None
array_mask = np.zeros(self.num_envs, dtype=bool)
return array, array_mask
return vector_infos
def __del__(self):
"""Closes the vector environment."""
@@ -441,23 +444,23 @@ class VectorObservationWrapper(VectorWrapper):
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]:
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
obs, info = self.env.reset(seed=seed, options=options)
return self.vector_observation(obs), info
observations, infos = self.env.reset(seed=seed, options=options)
return self.observation(observations), infos
def step(
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
observation, reward, termination, truncation, info = self.env.step(actions)
observations, rewards, terminations, truncations, infos = self.env.step(actions)
return (
self.vector_observation(observation),
reward,
termination,
truncation,
self.update_final_obs(info),
self.observation(observations),
rewards,
terminations,
truncations,
infos,
)
def vector_observation(self, observation: ObsType) -> ObsType:
def observation(self, observation: ObsType) -> ObsType:
"""Defines the vector observation transformation.
Args:
@@ -468,25 +471,6 @@ class VectorObservationWrapper(VectorWrapper):
"""
raise NotImplementedError
def single_observation(self, observation: ObsType) -> ObsType:
"""Defines the single observation transformation.
Args:
observation: A single observation from the environment
Returns:
The transformed observation
"""
raise NotImplementedError
def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
"""Updates the `final_obs` in the info using `single_observation`."""
if "final_observation" in info:
for i, obs in enumerate(info["final_observation"]):
if obs is not None:
info["final_observation"][i] = self.single_observation(obs)
return info
class VectorActionWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the actions.
@@ -522,14 +506,14 @@ class VectorRewardWrapper(VectorWrapper):
self, actions: ActType
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
"""Steps through the environment returning a reward modified by :meth:`reward`."""
observation, reward, termination, truncation, info = self.env.step(actions)
return observation, self.rewards(reward), termination, truncation, info
observations, rewards, terminations, truncations, infos = self.env.step(actions)
return observations, self.rewards(rewards), terminations, truncations, infos
def rewards(self, reward: ArrayType) -> ArrayType:
def rewards(self, rewards: ArrayType) -> ArrayType:
"""Transform the reward before returning it.
Args:
reward (array): the reward to transform
rewards (array): the rewards to transform
Returns:
array: the transformed reward