New info API for vectorized environments #2657 (#2773)

* WIP refactor info API sync vector.

* Add missing untracked file.

* Add info strategy to reset_wait.

* Add interface and docstring.

* info with strategy pattern on async vector env.

* Add default to async vecenv.

* episode statistics for asyncvecnev.

* Add tests info strategy format.

* Add info strategy to reset_wait.

* refactor and cleanup.

* Code cleanup. Add tests.

* Add tests for video recording with new info format.

* fix test case.

* fix camelcase.

* rename enum.

* update tests, docstrings, cleanup.

* Changes brax strategy to numpy. add_strategy method in StrategyFactory. Add tests.

* fix docstring and logging format.

* Set Brax info format as default. Remove classic info format. Update tests.

* breaking the wrong loop.

* WIP: wrapper.

* Add wrapper for brax to classic info.

* WIP: wrapper with nested RecordEpisodeStatistic.

* Add tests. Refactor docstrings. Cleanup.

* cleanup.

* patch conflicts.

* rebase and conflicts.

* new pre-commit conventions.

* docstring.

* renaming.

* incorporate info_processor in vecEnv.

* renaming. Create info dict only if needed.

* remove all brax references. update docstring. Update duplicate test.

* reviews.

* pre-commit.

* reviews.

* docstring.

* cleanup blank lines.

* add support for numpy dtypes.

* docstring fix.

* formatting.

* naming.

* assert correct info from wrappers chaining. Test correct wrappers chaining. naming.

* simplify episode_statistics.

* change args orer.

* update tests.

* wip: refactor episode_statistics.

* Add test for add_vecore_episode_statistics.
This commit is contained in:
Gianluca De Cola
2022-05-24 16:36:35 +02:00
committed by GitHub
parent bbf8f5a467
commit 49d8299a1e
13 changed files with 428 additions and 42 deletions

View File

@@ -271,8 +271,10 @@ class AsyncVectorEnv(VectorEnv):
self._state = AsyncState.DEFAULT
if return_info:
results, infos = zip(*results)
infos = list(infos)
infos = {}
results, info_data = zip(*results)
for i, info in enumerate(info_data):
infos = self._add_info(infos, info, i)
if not self.shared_memory:
self.observations = concatenate(
@@ -344,10 +346,20 @@ class AsyncVectorEnv(VectorEnv):
f"The call to `step_wait` has timed out after {timeout} second(s)."
)
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
observations_list, rewards, dones, infos = [], [], [], {}
successes = []
for i, pipe in enumerate(self.parent_pipes):
result, success = pipe.recv()
obs, rew, done, info = result
successes.append(success)
observations_list.append(obs)
rewards.append(rew)
dones.append(done)
infos = self._add_info(infos, info, i)
self._raise_if_errors(successes)
self._state = AsyncState.DEFAULT
observations_list, rewards, dones, infos = zip(*results)
if not self.shared_memory:
self.observations = concatenate(