Gymnasium/tests/wrappers/test_vector_list_info.py
Gianluca De Cola 49d8299a1e New info API for vectorized environments #2657 (#2773)
* WIP refactor info API sync vector.

* Add missing untracked file.

* Add info strategy to reset_wait.

* Add interface and docstring.

* info with strategy pattern on async vector env.

* Add default to async vecenv.

* episode statistics for asyncvecenv.

* Add tests info strategy format.

* Add info strategy to reset_wait.

* refactor and cleanup.

* Code cleanup. Add tests.

* Add tests for video recording with new info format.

* fix test case.

* fix camelcase.

* rename enum.

* update tests, docstrings, cleanup.

* Changes brax strategy to numpy. add_strategy method in StrategyFactory. Add tests.

* fix docstring and logging format.

* Set Brax info format as default. Remove classic info format. Update tests.

* breaking the wrong loop.

* WIP: wrapper.

* Add wrapper for brax to classic info.

* WIP: wrapper with nested RecordEpisodeStatistic.

* Add tests. Refactor docstrings. Cleanup.

* cleanup.

* patch conflicts.

* rebase and conflicts.

* new pre-commit conventions.

* docstring.

* renaming.

* incorporate info_processor in vecEnv.

* renaming. Create info dict only if needed.

* remove all brax references. update docstring. Update duplicate test.

* reviews.

* pre-commit.

* reviews.

* docstring.

* cleanup blank lines.

* add support for numpy dtypes.

* docstring fix.

* formatting.

* naming.

* assert correct info from wrappers chaining. Test correct wrappers chaining. naming.

* simplify episode_statistics.

* change args order.

* update tests.

* wip: refactor episode_statistics.

* Add test for add_vector_episode_statistics.
2022-05-24 10:36:35 -04:00


import pytest

import gym
from gym.wrappers import RecordEpisodeStatistics, VectorListInfo

ENV_ID = "CartPole-v1"
NUM_ENVS = 3
ENV_STEPS = 50
SEED = 42
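

def test_usage_in_vector_env():
    env = gym.make(ENV_ID)
    vector_env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)

    # Wrapping a vector environment must succeed.
    VectorListInfo(vector_env)

    # Wrapping a single (non-vector) environment must be rejected.
    with pytest.raises(AssertionError):
        VectorListInfo(env)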


def test_info_to_list():
    env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
    wrapped_env = VectorListInfo(env_to_wrap)
    wrapped_env.action_space.seed(SEED)
    _, info = wrapped_env.reset(seed=SEED, return_info=True)
    # After wrapping, info is a list with one entry per sub-environment.
    assert isinstance(info, list)
    assert len(info) == NUM_ENVS

    for _ in range(ENV_STEPS):
        action = wrapped_env.action_space.sample()
        _, _, dones, list_info = wrapped_env.step(action)
        for i, done in enumerate(dones):
            # Only sub-environments that just finished carry a terminal observation.
            if done:
                assert "terminal_observation" in list_info[i]
            else:
                assert "terminal_observation" not in list_info[i]


def test_info_to_list_statistics():
    env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
    # Chain the wrappers: episode statistics first, then conversion to a list.
    wrapped_env = VectorListInfo(RecordEpisodeStatistics(env_to_wrap))
    _, info = wrapped_env.reset(seed=SEED, return_info=True)
    wrapped_env.action_space.seed(SEED)
    assert isinstance(info, list)
    assert len(info) == NUM_ENVS

    for _ in range(ENV_STEPS):
        action = wrapped_env.action_space.sample()
        _, _, dones, list_info = wrapped_env.step(action)
        for i, done in enumerate(dones):
            if done:
                # Finished episodes report return ("r"), length ("l") and time ("t").
                assert "episode" in list_info[i]
                for stats in ["r", "l", "t"]:
                    assert stats in list_info[i]["episode"]
                    assert isinstance(list_info[i]["episode"][stats], float)
            else:
                assert "episode" not in list_info[i]
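
The tests above check the output of the conversion but not how it might work. The following is a minimal sketch of the idea only, not the wrapper's actual implementation. It assumes a vector-style info dict in which each key `k` maps to a sequence with one slot per sub-environment, and a companion key `_k` holds a boolean mask marking which sub-environments actually set `k`. The function name `dict_info_to_list` and the example data are hypothetical, and nested entries such as `"episode"` are not handled.

```python
from typing import Any, Dict, List


def dict_info_to_list(infos: Dict[str, Any], num_envs: int) -> List[Dict[str, Any]]:
    """Convert a vector-style info dict into one dict per sub-environment.

    Assumed layout (hypothetical, for illustration): key ``k`` holds a
    sequence of length ``num_envs`` and key ``_k`` holds a boolean mask
    saying which entries of ``k`` are meaningful.
    """
    list_info: List[Dict[str, Any]] = [{} for _ in range(num_envs)]
    for key, values in infos.items():
        if key.startswith("_"):
            continue  # mask keys are only read below, never copied
        # If no mask is present, assume every sub-environment set the key.
        mask = infos.get(f"_{key}", [True] * num_envs)
        for i in range(num_envs):
            if mask[i]:
                list_info[i][key] = values[i]
    return list_info


# Hypothetical data: only the second sub-environment finished an episode.
example = {
    "terminal_observation": [None, [0.1, 0.2], None],
    "_terminal_observation": [False, True, False],
}
result = dict_info_to_list(example, num_envs=3)
```

With this input, only the entry at index 1 contains `"terminal_observation"`, which mirrors the per-environment check in `test_info_to_list`.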