New info API for vectorized environments #2657 (#2773)

* WIP refactor info API sync vector.

* Add missing untracked file.

* Add info strategy to reset_wait.

* Add interface and docstring.

* info with strategy pattern on async vector env.

* Add default to async vecenv.

* episode statistics for asyncvecnev.

* Add tests info strategy format.

* Add info strategy to reset_wait.

* refactor and cleanup.

* Code cleanup. Add tests.

* Add tests for video recording with new info format.

* fix test case.

* fix camelcase.

* rename enum.

* update tests, docstrings, cleanup.

* Changes brax strategy to numpy. add_strategy method in StrategyFactory. Add tests.

* fix docstring and logging format.

* Set Brax info format as default. Remove classic info format. Update tests.

* breaking the wrong loop.

* WIP: wrapper.

* Add wrapper for brax to classic info.

* WIP: wrapper with nested RecordEpisodeStatistic.

* Add tests. Refactor docstrings. Cleanup.

* cleanup.

* patch conflicts.

* rebase and conflicts.

* new pre-commit conventions.

* docstring.

* renaming.

* incorporate info_processor in vecEnv.

* renaming. Create info dict only if needed.

* remove all brax references. update docstring. Update duplicate test.

* reviews.

* pre-commit.

* reviews.

* docstring.

* cleanup blank lines.

* add support for numpy dtypes.

* docstring fix.

* formatting.

* naming.

* assert correct info from wrappers chaining. Test correct wrappers chaining. naming.

* simplify episode_statistics.

* change args orer.

* update tests.

* wip: refactor episode_statistics.

* Add test for add_vecore_episode_statistics.
This commit is contained in:
Gianluca De Cola
2022-05-24 16:36:35 +02:00
committed by GitHub
parent bbf8f5a467
commit 49d8299a1e
13 changed files with 428 additions and 42 deletions

View File

@@ -108,8 +108,8 @@ class SyncVectorEnv(VectorEnv):
self._dones[:] = False
observations = []
data_list = []
for env, single_seed in zip(self.envs, seed):
infos = {}
for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
kwargs = {}
if single_seed is not None:
@@ -123,9 +123,9 @@ class SyncVectorEnv(VectorEnv):
observation = env.reset(**kwargs)
observations.append(observation)
else:
observation, data = env.reset(**kwargs)
observation, info = env.reset(**kwargs)
observations.append(observation)
data_list.append(data)
infos = self._add_info(infos, info, i)
self.observations = concatenate(
self.single_observation_space, observations, self.observations
@@ -135,7 +135,7 @@ class SyncVectorEnv(VectorEnv):
else:
return (
deepcopy(self.observations) if self.copy else self.observations
), data_list
), infos
def step_async(self, actions):
"""Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
@@ -147,14 +147,14 @@ class SyncVectorEnv(VectorEnv):
Returns:
The batched environment step results
"""
observations, infos = [], []
observations, infos = [], {}
for i, (env, action) in enumerate(zip(self.envs, self._actions)):
observation, self._rewards[i], self._dones[i], info = env.step(action)
if self._dones[i]:
info["terminal_observation"] = observation
observation = env.reset()
observations.append(observation)
infos.append(info)
infos = self._add_info(infos, info, i)
self.observations = concatenate(
self.single_observation_space, observations, self.observations
)