mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-29 17:45:07 +00:00
Adding return_info argument to reset to allow for optional info dict as a second return value (#2546)
* initial draft of optional info dict in reset function, implemented for cartpole, tests seem to be passing * merged core.py * updated return type annotation for reset function in core.py * optional metadata with return_info from reset added for all first party environments, with corresponding tests. Incomplete implementation for wrappers and vector wrappers * removed Optional type for return_info arguments * added tests for return_info to normalize wrapper and sync_vector_env * autoformatted using black * added optional reset metadata tests to several wrappers * added return_info capability to async_vector_env.py and test to verify functionality * added optional return_info test for record_video.py * removed tests for mujoco environments * autoformatted * improved test coverage for optional reset return_info * re-removed unit test envs accidentally reintroduced in merge * removed unnecessary import * changes based on code-review * small fix to core wrapper typing and autoformatted record_epsisode_stats * small change to pass flake8 style
This commit is contained in:
@@ -40,6 +40,32 @@ def test_reset_async_vector_env(shared_memory):
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
observations = env.reset(return_info=False)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert isinstance(observations, np.ndarray)
|
||||
assert observations.dtype == env.observation_space.dtype
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
observations, infos = env.reset(return_info=True)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert isinstance(observations, np.ndarray)
|
||||
assert observations.dtype == env.observation_space.dtype
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
assert isinstance(infos, list)
|
||||
assert all([isinstance(info, dict) for info in infos])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
||||
|
Reference in New Issue
Block a user