From 27f8e8505184eb312a0841e28a7860fb8dbc2566 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Tue, 7 Nov 2023 13:27:25 +0000 Subject: [PATCH] Merge v1.0.0 (#682) Co-authored-by: Kallinteris Andreas <30759571+Kallinteris-Andreas@users.noreply.github.com> Co-authored-by: Jet <38184875+jjshoots@users.noreply.github.com> Co-authored-by: Omar Younis <42100908+younik@users.noreply.github.com> --- .github/workflows/build-docs.yml | 96 +-- .github/workflows/build.yml | 2 +- .github/workflows/docs-manual-versioning.yml | 144 ++-- .github/workflows/docs-versioning.yml | 115 +-- .pre-commit-config.yaml | 2 +- docs/_scripts/gen_wrapper_table.py | 73 ++ docs/api/env.md | 44 +- docs/api/experimental.md | 157 ---- docs/api/experimental/functional.md | 37 - docs/api/experimental/vector.md | 39 - docs/api/experimental/vector_utils.md | 29 - docs/api/experimental/vector_wrappers.md | 53 -- docs/api/experimental/wrappers.md | 65 -- docs/api/functional.md | 34 + docs/api/registry.md | 7 +- docs/api/spaces.md | 81 +- docs/api/spaces/composite.md | 29 +- docs/api/spaces/fundamental.md | 38 +- docs/api/utils.md | 34 +- docs/api/vector.md | 65 +- docs/api/vector/async_vector_env.md | 13 + docs/api/vector/sync_vector_env.md | 13 + .../vector_utils.md => vector/utils.md} | 17 +- docs/api/vector/wrappers.md | 26 + docs/api/wrappers.md | 125 +-- docs/api/wrappers/action_wrappers.md | 4 +- docs/api/wrappers/misc_wrappers.md | 31 +- docs/api/wrappers/observation_wrappers.md | 15 +- docs/api/wrappers/reward_wrappers.md | 8 +- docs/api/wrappers/table.md | 102 +++ docs/api/wrappers/vector_wrappers.md | 19 + docs/conf.py | 2 +- docs/content/basic_usage.md | 134 ---- docs/content/migration-guide.md | 122 --- docs/index.md | 28 +- docs/introduction/basic_usage.md | 172 +++++ .../gym_compatibility.md | 8 +- docs/introduction/migration-guide.md | 103 +++ gymnasium/__init__.py | 7 +- gymnasium/core.py | 174 ++--- gymnasium/envs/__init__.py | 13 +- gymnasium/envs/box2d/lunar_lander.py | 2 +- gymnasium/envs/classic_control/cartpole.py | 38 +- .../functional_jax_env.py | 10 +- gymnasium/envs/phys2d/cartpole.py | 7 +- gymnasium/envs/phys2d/pendulum.py | 7 +- gymnasium/envs/registration.py | 267 +++---- gymnasium/envs/tabular/blackjack.py | 4 +- gymnasium/envs/tabular/cliffwalking.py | 4 +- gymnasium/experimental/__init__.py | 25 - gymnasium/experimental/vector/__init__.py | 23 - .../experimental/vector/async_vector_env.py | 685 ----------------- .../experimental/vector/sync_vector_env.py | 229 ------ .../experimental/vector/utils/__init__.py | 30 - gymnasium/experimental/vector/utils/misc.py | 61 -- .../vector/utils/shared_memory.py | 255 ------- gymnasium/experimental/vector/vector_env.py | 486 ------------ gymnasium/experimental/wrappers/__init__.py | 164 ---- .../wrappers/atari_preprocessing.py | 206 ----- gymnasium/experimental/wrappers/common.py | 315 -------- .../wrappers/lambda_observation.py | 620 --------------- .../experimental/wrappers/lambda_reward.py | 102 --- .../experimental/wrappers/vector/__init__.py | 146 ---- .../wrappers/vector/dict_info_to_list.py | 86 --- .../wrappers/vector/vectorize_action.py | 143 ---- .../wrappers/vector/vectorize_observation.py | 222 ------ .../wrappers/vector/vectorize_reward.py | 78 -- gymnasium/{experimental => }/functional.py | 31 +- gymnasium/spaces/box.py | 2 +- gymnasium/spaces/dict.py | 2 +- gymnasium/spaces/discrete.py | 4 +- gymnasium/spaces/graph.py | 12 +- gymnasium/spaces/multi_binary.py | 4 +- gymnasium/spaces/multi_discrete.py | 9 +- gymnasium/spaces/sequence.py | 7 +- 
gymnasium/spaces/text.py | 8 +- gymnasium/spaces/tuple.py | 4 +- gymnasium/spaces/utils.py | 13 +- gymnasium/utils/env_checker.py | 29 +- gymnasium/utils/passive_env_checker.py | 2 + gymnasium/utils/play.py | 74 +- gymnasium/utils/save_video.py | 4 +- gymnasium/utils/seeding.py | 17 +- gymnasium/utils/step_api_compatibility.py | 55 +- gymnasium/vector/__init__.py | 90 +-- gymnasium/vector/async_vector_env.py | 439 ++++++----- gymnasium/vector/sync_vector_env.py | 247 +++--- gymnasium/vector/utils/__init__.py | 21 +- gymnasium/vector/utils/misc.py | 2 +- gymnasium/vector/utils/numpy_utils.py | 146 ---- gymnasium/vector/utils/shared_memory.py | 136 +++- .../vector/utils/space_utils.py | 8 +- gymnasium/vector/utils/spaces.py | 215 ------ gymnasium/vector/vector_env.py | 529 ++++++++----- gymnasium/wrappers/README.md | 18 - gymnasium/wrappers/__init__.py | 163 +++- gymnasium/wrappers/atari_preprocessing.py | 40 +- gymnasium/wrappers/autoreset.py | 86 --- gymnasium/wrappers/clip_action.py | 43 -- gymnasium/wrappers/common.py | 536 +++++++++++++ gymnasium/wrappers/compatibility.py | 129 ---- gymnasium/wrappers/env_checker.py | 95 --- gymnasium/wrappers/filter_observation.py | 92 --- gymnasium/wrappers/flatten_observation.py | 43 -- gymnasium/wrappers/frame_stack.py | 196 ----- gymnasium/wrappers/gray_scale_observation.py | 68 -- gymnasium/wrappers/human_rendering.py | 142 ---- .../wrappers/jax_to_numpy.py | 28 +- .../wrappers/jax_to_torch.py | 31 +- gymnasium/wrappers/monitoring/__init__.py | 1 - .../wrappers/monitoring/video_recorder.py | 178 ----- gymnasium/wrappers/normalize.py | 155 ---- .../wrappers/numpy_to_torch.py | 32 +- gymnasium/wrappers/order_enforcing.py | 89 --- gymnasium/wrappers/pixel_observation.py | 215 ------ .../wrappers/record_episode_statistics.py | 131 ---- gymnasium/wrappers/record_video.py | 228 ------ gymnasium/wrappers/render_collection.py | 62 -- .../{experimental => }/wrappers/rendering.py | 211 ++++-- gymnasium/wrappers/rescale_action.py | 89 --- gymnasium/wrappers/resize_observation.py | 83 -- .../wrappers/stateful_action.py | 26 +- .../wrappers/stateful_observation.py | 203 +++-- .../wrappers/stateful_reward.py | 67 +- gymnasium/wrappers/step_api_compatibility.py | 56 -- gymnasium/wrappers/time_aware_observation.py | 79 -- gymnasium/wrappers/time_limit.py | 89 --- .../transform_action.py} | 72 +- gymnasium/wrappers/transform_observation.py | 712 +++++++++++++++++- gymnasium/wrappers/transform_reward.py | 114 ++- .../{experimental => }/wrappers/utils.py | 6 +- gymnasium/wrappers/vector/__init__.py | 106 +++ .../vector/common.py} | 50 +- .../wrappers/vector/dict_info_to_list.py | 153 ++++ .../wrappers/vector/jax_to_numpy.py | 17 +- .../wrappers/vector/jax_to_torch.py | 19 +- .../wrappers/vector/numpy_to_torch.py | 39 +- .../wrappers/vector/stateful_observation.py | 111 +++ gymnasium/wrappers/vector/stateful_reward.py | 115 +++ gymnasium/wrappers/vector/vectorize_action.py | 253 +++++++ .../wrappers/vector/vectorize_observation.py | 394 ++++++++++ gymnasium/wrappers/vector/vectorize_reward.py | 163 ++++ gymnasium/wrappers/vector_list_info.py | 126 ---- pyproject.toml | 1 + tests/envs/functional/test_core.py | 2 +- tests/envs/mujoco/__init__.py | 0 tests/envs/registration/test_env_spec.py | 13 +- tests/envs/registration/test_make.py | 104 +-- tests/envs/registration/test_make_vec.py | 106 ++- tests/envs/test_action_dim_check.py | 4 +- tests/envs/test_compatibility.py | 190 ----- tests/envs/test_env_implementation.py | 6 +- tests/experimental/__init__.py | 1 - 
tests/experimental/vector/__init__.py | 1 - .../vector/test_async_vector_env.py | 329 -------- .../vector/test_sync_vector_env.py | 187 ----- tests/experimental/vector/test_vector_env.py | 126 ---- .../vector/test_vector_env_info.py | 62 -- tests/experimental/vector/utils/__init__.py | 1 - tests/experimental/wrappers/__init__.py | 1 - .../wrappers/test_atari_preprocessing.py | 1 - tests/experimental/wrappers/test_autoreset.py | 1 - .../experimental/wrappers/test_clip_action.py | 25 - .../experimental/wrappers/test_clip_reward.py | 67 -- .../wrappers/test_filter_observation.py | 45 -- .../wrappers/test_flatten_observation.py | 25 - .../wrappers/test_grayscale_observation.py | 38 - .../wrappers/test_human_rendering.py | 1 - .../wrappers/test_import_wrappers.py | 58 -- .../wrappers/test_lambda_rewards.py | 50 -- .../wrappers/test_normalize_reward.py | 55 -- .../wrappers/test_numpy_to_torch.py | 1 - .../wrappers/test_order_enforcing.py | 1 - .../wrappers/test_passive_env_checker.py | 1 - .../wrappers/test_pixel_observation.py | 1 - .../test_record_episode_statistics.py | 1 - .../wrappers/test_record_video.py | 168 ----- .../wrappers/test_render_collection.py | 1 - .../wrappers/test_rescale_action.py | 38 - .../wrappers/test_resize_observation.py | 57 -- .../wrappers/test_time_aware_observation.py | 75 -- tests/experimental/wrappers/utils.py | 77 -- .../experimental/wrappers/vector/__init__.py | 1 - .../{experimental => }/functional/__init__.py | 0 .../functional/test_func_jax_env.py | 0 .../functional/test_functional.py | 2 +- .../functional/test_jax_blackjack.py | 0 .../functional/test_jax_cliffwalking.py | 0 tests/test_core.py | 248 ++---- tests/utils/test_save_video.py | 7 +- tests/vector/__init__.py | 1 + tests/vector/test_async_vector_env.py | 66 +- tests/vector/test_numpy_utils.py | 142 ---- tests/vector/test_shared_memory.py | 189 ----- tests/vector/test_spaces.py | 205 ----- tests/vector/test_sync_vector_env.py | 67 +- tests/vector/test_vector_env.py | 47 +- tests/vector/test_vector_env_info.py | 19 +- tests/vector/test_vector_env_wrapper.py | 31 - tests/vector/test_vector_make.py | 83 -- .../vector/test_vector_wrapper.py | 37 +- .../vector/testing_utils.py | 2 +- tests/vector/utils.py | 141 ---- tests/vector/utils/__init__.py | 1 + .../vector/utils/test_shared_memory.py | 6 +- .../vector/utils/test_space_utils.py | 11 +- .../{experimental => }/vector/utils/utils.py | 0 tests/wrappers/__init__.py | 1 + tests/wrappers/test_atari_preprocessing.py | 28 +- tests/wrappers/test_autoreset.py | 107 +-- tests/wrappers/test_clip_action.py | 37 +- tests/wrappers/test_clip_reward.py | 42 ++ .../wrappers/test_delay_observation.py | 20 +- .../wrappers/test_dtype_observation.py | 11 +- tests/wrappers/test_filter_observation.py | 140 ++-- tests/wrappers/test_flatten.py | 98 --- tests/wrappers/test_flatten_observation.py | 36 +- tests/wrappers/test_frame_stack.py | 53 -- .../wrappers/test_frame_stack_observation.py | 22 +- tests/wrappers/test_gray_scale_observation.py | 52 +- tests/wrappers/test_human_rendering.py | 1 + tests/wrappers/test_import_wrappers.py | 51 ++ .../wrappers/test_jax_to_numpy.py | 14 +- .../wrappers/test_jax_to_torch.py | 8 +- .../wrappers/test_lambda_action.py | 8 +- .../wrappers/test_lambda_observation.py | 12 +- tests/wrappers/test_lambda_reward.py | 15 + .../wrappers/test_max_and_skip_observation.py | 16 +- tests/wrappers/test_nested_dict.py | 120 --- tests/wrappers/test_normalize.py | 125 --- .../wrappers/test_normalize_observation.py | 9 +- 
tests/wrappers/test_normalize_reward.py | 83 ++ tests/wrappers/test_numpy_to_torch.py | 110 +++ tests/wrappers/test_order_enforcing.py | 13 +- tests/wrappers/test_passive_env_checker.py | 4 +- tests/wrappers/test_pixel_observation.py | 125 --- .../test_record_episode_statistics.py | 58 +- tests/wrappers/test_record_video.py | 136 +++- tests/wrappers/test_render_observation.py | 96 +++ tests/wrappers/test_rescale_action.py | 53 +- .../wrappers/test_rescale_observation.py | 12 +- .../wrappers/test_reshape_observation.py | 10 +- tests/wrappers/test_resize_observation.py | 83 +- tests/wrappers/test_step_compatibility.py | 98 --- .../wrappers/test_sticky_action.py | 10 +- tests/wrappers/test_time_aware_observation.py | 87 ++- tests/wrappers/test_time_limit.py | 2 + tests/wrappers/test_transform_observation.py | 36 - tests/wrappers/test_transform_reward.py | 63 -- tests/wrappers/test_video_recorder.py | 92 --- tests/wrappers/utils.py | 82 ++ tests/wrappers/vector/__init__.py | 1 + .../test_dict_info_to_list.py} | 18 +- .../vector/test_normalize_observation.py | 68 ++ .../wrappers/vector/test_normalize_reward.py | 70 ++ .../wrappers/vector/test_vector_wrappers.py | 50 +- 256 files changed, 7051 insertions(+), 13421 deletions(-) create mode 100644 docs/_scripts/gen_wrapper_table.py delete mode 100644 docs/api/experimental.md delete mode 100644 docs/api/experimental/vector.md delete mode 100644 docs/api/experimental/vector_utils.md delete mode 100644 docs/api/experimental/vector_wrappers.md delete mode 100644 docs/api/experimental/wrappers.md create mode 100644 docs/api/functional.md create mode 100644 docs/api/vector/async_vector_env.md create mode 100644 docs/api/vector/sync_vector_env.md rename docs/api/{spaces/vector_utils.md => vector/utils.md} (66%) create mode 100644 docs/api/vector/wrappers.md create mode 100644 docs/api/wrappers/table.md create mode 100644 docs/api/wrappers/vector_wrappers.md delete mode 100644 docs/content/basic_usage.md delete mode 100644 docs/content/migration-guide.md create mode 100644 docs/introduction/basic_usage.md rename docs/{content => introduction}/gym_compatibility.md (65%) create mode 100644 docs/introduction/migration-guide.md rename gymnasium/{experimental => envs}/functional_jax_env.py (96%) delete mode 100644 gymnasium/experimental/__init__.py delete mode 100644 gymnasium/experimental/vector/__init__.py delete mode 100644 gymnasium/experimental/vector/async_vector_env.py delete mode 100644 gymnasium/experimental/vector/sync_vector_env.py delete mode 100644 gymnasium/experimental/vector/utils/__init__.py delete mode 100644 gymnasium/experimental/vector/utils/misc.py delete mode 100644 gymnasium/experimental/vector/utils/shared_memory.py delete mode 100644 gymnasium/experimental/vector/vector_env.py delete mode 100644 gymnasium/experimental/wrappers/__init__.py delete mode 100644 gymnasium/experimental/wrappers/atari_preprocessing.py delete mode 100644 gymnasium/experimental/wrappers/common.py delete mode 100644 gymnasium/experimental/wrappers/lambda_observation.py delete mode 100644 gymnasium/experimental/wrappers/lambda_reward.py delete mode 100644 gymnasium/experimental/wrappers/vector/__init__.py delete mode 100644 gymnasium/experimental/wrappers/vector/dict_info_to_list.py delete mode 100644 gymnasium/experimental/wrappers/vector/vectorize_action.py delete mode 100644 gymnasium/experimental/wrappers/vector/vectorize_observation.py delete mode 100644 gymnasium/experimental/wrappers/vector/vectorize_reward.py rename gymnasium/{experimental => 
}/functional.py (75%) delete mode 100644 gymnasium/vector/utils/numpy_utils.py rename gymnasium/{experimental => }/vector/utils/space_utils.py (98%) delete mode 100644 gymnasium/vector/utils/spaces.py delete mode 100644 gymnasium/wrappers/README.md delete mode 100644 gymnasium/wrappers/autoreset.py delete mode 100644 gymnasium/wrappers/clip_action.py create mode 100644 gymnasium/wrappers/common.py delete mode 100644 gymnasium/wrappers/compatibility.py delete mode 100644 gymnasium/wrappers/env_checker.py delete mode 100644 gymnasium/wrappers/filter_observation.py delete mode 100644 gymnasium/wrappers/flatten_observation.py delete mode 100644 gymnasium/wrappers/frame_stack.py delete mode 100644 gymnasium/wrappers/gray_scale_observation.py delete mode 100644 gymnasium/wrappers/human_rendering.py rename gymnasium/{experimental => }/wrappers/jax_to_numpy.py (78%) rename gymnasium/{experimental => }/wrappers/jax_to_torch.py (79%) delete mode 100644 gymnasium/wrappers/monitoring/__init__.py delete mode 100644 gymnasium/wrappers/monitoring/video_recorder.py delete mode 100644 gymnasium/wrappers/normalize.py rename gymnasium/{experimental => }/wrappers/numpy_to_torch.py (82%) delete mode 100644 gymnasium/wrappers/order_enforcing.py delete mode 100644 gymnasium/wrappers/pixel_observation.py delete mode 100644 gymnasium/wrappers/record_episode_statistics.py delete mode 100644 gymnasium/wrappers/record_video.py delete mode 100644 gymnasium/wrappers/render_collection.py rename gymnasium/{experimental => }/wrappers/rendering.py (72%) delete mode 100644 gymnasium/wrappers/rescale_action.py delete mode 100644 gymnasium/wrappers/resize_observation.py rename gymnasium/{experimental => }/wrappers/stateful_action.py (64%) rename gymnasium/{experimental => }/wrappers/stateful_observation.py (73%) rename gymnasium/{experimental => }/wrappers/stateful_reward.py (52%) delete mode 100644 gymnasium/wrappers/step_api_compatibility.py delete mode 100644 gymnasium/wrappers/time_aware_observation.py delete mode 100644 gymnasium/wrappers/time_limit.py rename gymnasium/{experimental/wrappers/lambda_action.py => wrappers/transform_action.py} (69%) rename gymnasium/{experimental => }/wrappers/utils.py (96%) create mode 100644 gymnasium/wrappers/vector/__init__.py rename gymnasium/{experimental/wrappers/vector/record_episode_statistics.py => wrappers/vector/common.py} (73%) create mode 100644 gymnasium/wrappers/vector/dict_info_to_list.py rename gymnasium/{experimental => }/wrappers/vector/jax_to_numpy.py (78%) rename gymnasium/{experimental => }/wrappers/vector/jax_to_torch.py (80%) rename gymnasium/{experimental => }/wrappers/vector/numpy_to_torch.py (66%) create mode 100644 gymnasium/wrappers/vector/stateful_observation.py create mode 100644 gymnasium/wrappers/vector/stateful_reward.py create mode 100644 gymnasium/wrappers/vector/vectorize_action.py create mode 100644 gymnasium/wrappers/vector/vectorize_observation.py create mode 100644 gymnasium/wrappers/vector/vectorize_reward.py delete mode 100644 gymnasium/wrappers/vector_list_info.py create mode 100644 tests/envs/mujoco/__init__.py delete mode 100644 tests/envs/test_compatibility.py delete mode 100644 tests/experimental/__init__.py delete mode 100644 tests/experimental/vector/__init__.py delete mode 100644 tests/experimental/vector/test_async_vector_env.py delete mode 100644 tests/experimental/vector/test_sync_vector_env.py delete mode 100644 tests/experimental/vector/test_vector_env.py delete mode 100644 tests/experimental/vector/test_vector_env_info.py delete mode 
100644 tests/experimental/vector/utils/__init__.py delete mode 100644 tests/experimental/wrappers/__init__.py delete mode 100644 tests/experimental/wrappers/test_atari_preprocessing.py delete mode 100644 tests/experimental/wrappers/test_autoreset.py delete mode 100644 tests/experimental/wrappers/test_clip_action.py delete mode 100644 tests/experimental/wrappers/test_clip_reward.py delete mode 100644 tests/experimental/wrappers/test_filter_observation.py delete mode 100644 tests/experimental/wrappers/test_flatten_observation.py delete mode 100644 tests/experimental/wrappers/test_grayscale_observation.py delete mode 100644 tests/experimental/wrappers/test_human_rendering.py delete mode 100644 tests/experimental/wrappers/test_import_wrappers.py delete mode 100644 tests/experimental/wrappers/test_lambda_rewards.py delete mode 100644 tests/experimental/wrappers/test_normalize_reward.py delete mode 100644 tests/experimental/wrappers/test_numpy_to_torch.py delete mode 100644 tests/experimental/wrappers/test_order_enforcing.py delete mode 100644 tests/experimental/wrappers/test_passive_env_checker.py delete mode 100644 tests/experimental/wrappers/test_pixel_observation.py delete mode 100644 tests/experimental/wrappers/test_record_episode_statistics.py delete mode 100644 tests/experimental/wrappers/test_record_video.py delete mode 100644 tests/experimental/wrappers/test_render_collection.py delete mode 100644 tests/experimental/wrappers/test_rescale_action.py delete mode 100644 tests/experimental/wrappers/test_resize_observation.py delete mode 100644 tests/experimental/wrappers/test_time_aware_observation.py delete mode 100644 tests/experimental/wrappers/utils.py delete mode 100644 tests/experimental/wrappers/vector/__init__.py rename tests/{experimental => }/functional/__init__.py (100%) rename tests/{experimental => }/functional/test_func_jax_env.py (100%) rename tests/{experimental => }/functional/test_functional.py (97%) rename tests/{experimental => }/functional/test_jax_blackjack.py (100%) rename tests/{experimental => }/functional/test_jax_cliffwalking.py (100%) delete mode 100644 tests/vector/test_numpy_utils.py delete mode 100644 tests/vector/test_shared_memory.py delete mode 100644 tests/vector/test_spaces.py delete mode 100644 tests/vector/test_vector_env_wrapper.py delete mode 100644 tests/vector/test_vector_make.py rename tests/{experimental => }/vector/test_vector_wrapper.py (51%) rename tests/{experimental => }/vector/testing_utils.py (98%) delete mode 100644 tests/vector/utils.py create mode 100644 tests/vector/utils/__init__.py rename tests/{experimental => }/vector/utils/test_shared_memory.py (96%) rename tests/{experimental => }/vector/utils/test_space_utils.py (96%) rename tests/{experimental => }/vector/utils/utils.py (100%) create mode 100644 tests/wrappers/test_clip_reward.py rename tests/{experimental => }/wrappers/test_delay_observation.py (81%) rename tests/{experimental => }/wrappers/test_dtype_observation.py (65%) delete mode 100644 tests/wrappers/test_flatten.py delete mode 100644 tests/wrappers/test_frame_stack.py rename tests/{experimental => }/wrappers/test_frame_stack_observation.py (81%) create mode 100644 tests/wrappers/test_import_wrappers.py rename tests/{experimental => }/wrappers/test_jax_to_numpy.py (93%) rename tests/{experimental => }/wrappers/test_jax_to_torch.py (95%) rename tests/{experimental => }/wrappers/test_lambda_action.py (68%) rename tests/{experimental => }/wrappers/test_lambda_observation.py (66%) create mode 100644 
tests/wrappers/test_lambda_reward.py rename tests/{experimental => }/wrappers/test_max_and_skip_observation.py (64%) delete mode 100644 tests/wrappers/test_nested_dict.py delete mode 100644 tests/wrappers/test_normalize.py rename tests/{experimental => }/wrappers/test_normalize_observation.py (81%) create mode 100644 tests/wrappers/test_normalize_reward.py create mode 100644 tests/wrappers/test_numpy_to_torch.py delete mode 100644 tests/wrappers/test_pixel_observation.py create mode 100644 tests/wrappers/test_render_observation.py rename tests/{experimental => }/wrappers/test_rescale_observation.py (85%) rename tests/{experimental => }/wrappers/test_reshape_observation.py (76%) delete mode 100644 tests/wrappers/test_step_compatibility.py rename tests/{experimental => }/wrappers/test_sticky_action.py (81%) delete mode 100644 tests/wrappers/test_transform_observation.py delete mode 100644 tests/wrappers/test_transform_reward.py delete mode 100644 tests/wrappers/test_video_recorder.py create mode 100644 tests/wrappers/vector/__init__.py rename tests/wrappers/{test_vector_list_info.py => vector/test_dict_info_to_list.py} (74%) create mode 100644 tests/wrappers/vector/test_normalize_observation.py create mode 100644 tests/wrappers/vector/test_normalize_reward.py rename tests/{experimental => }/wrappers/vector/test_vector_wrappers.py (63%) diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml index 6fbe46227..0107438a5 100644 --- a/.github/workflows/build-docs.yml +++ b/.github/workflows/build-docs.yml @@ -1,46 +1,50 @@ -name: Build main branch documentation website -on: - push: - branches: [main] -permissions: - contents: write -jobs: - docs: - name: Generate Website - runs-on: ubuntu-latest - env: - SPHINX_GITHUB_CHANGELOG_TOKEN: ${{ secrets.GITHUB_TOKEN }} - steps: - - uses: actions/checkout@v3 - - - uses: actions/setup-python@v4 - with: - python-version: '3.9' - - - name: Install dependencies - run: pip install -r docs/requirements.txt - - - name: Install Gymnasium - run: pip install .[box2d,mujoco,jax] torch - - - name: Build Envs Docs - run: python docs/_scripts/gen_mds.py && python docs/_scripts/gen_envs_display.py - - - name: Build - run: sphinx-build -b dirhtml -v docs _build - - - name: Move 404 - run: mv _build/404/index.html _build/404.html - - - name: Update 404 links - run: python docs/_scripts/move_404.py _build/404.html - - - name: Remove .doctrees - run: rm -r _build/.doctrees - - - name: Upload to GitHub Pages - uses: JamesIves/github-pages-deploy-action@v4 - with: - folder: _build - target-folder: main - clean: false +name: Build main branch documentation website + +on: + push: + branches: [main] + +permissions: + contents: write + +jobs: + docs: + name: Generate Website + runs-on: ubuntu-latest + env: + SPHINX_GITHUB_CHANGELOG_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v4 + with: + python-version: '3.9' + + - name: Install dependencies + run: pip install -r docs/requirements.txt + + - name: Install Gymnasium + run: pip install .[box2d,mujoco,jax] torch + + - name: Build Envs Docs + run: python docs/_scripts/gen_mds.py && python docs/_scripts/gen_envs_display.py + + - name: Build + run: sphinx-build -b dirhtml -v docs _build + + - name: Move 404 + run: mv _build/404/index.html _build/404.html + + - name: Update 404 links + run: python docs/_scripts/move_404.py _build/404.html + + - name: Remove .doctrees + run: rm -r _build/.doctrees + + - name: Upload to GitHub Pages + uses: 
JamesIves/github-pages-deploy-action@v4 + with: + folder: _build + target-folder: main + clean: false diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ffc47700c..6087a696a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -32,4 +32,4 @@ jobs: --tag gymnasium-necessary-docker . - name: Run tests run: | - docker run gymnasium-necessary-docker pytest tests/test_core.py tests/envs/test_compatibility.py tests/envs/test_envs.py tests/spaces + docker run gymnasium-necessary-docker pytest tests/test_core.py tests/envs/test_envs.py tests/spaces diff --git a/.github/workflows/docs-manual-versioning.yml b/.github/workflows/docs-manual-versioning.yml index 9b816f7c5..f3f738649 100644 --- a/.github/workflows/docs-manual-versioning.yml +++ b/.github/workflows/docs-manual-versioning.yml @@ -1,71 +1,73 @@ -name: Manual Docs Versioning -on: - workflow_dispatch: - inputs: - version: - description: 'Documentation version to create' - required: true - commit: - description: 'Commit used to build the Documentation version' - required: false - latest: - description: 'Latest version' - type: boolean - -permissions: - contents: write -jobs: - docs: - name: Generate Website for new version - runs-on: ubuntu-latest - env: - SPHINX_GITHUB_CHANGELOG_TOKEN: ${{ secrets.GITHUB_TOKEN }} - steps: - - uses: actions/checkout@v3 - if: inputs.commit == '' - - - uses: actions/checkout@v3 - if: inputs.commit != '' - with: - ref: ${{ inputs.commit }} - - - uses: actions/setup-python@v4 - with: - python-version: '3.9' - - - name: Install dependencies - run: pip install -r docs/requirements.txt - - - name: Install Gymnasium - run: pip install .[box2d,mujoco,jax] torch - - - name: Build Envs Docs - run: python docs/_scripts/gen_mds.py && python docs/_scripts/gen_envs_display.py - - - name: Build - run: sphinx-build -b dirhtml -v docs _build - - - name: Move 404 - run: mv _build/404/index.html _build/404.html - - - name: Update 404 links - run: python docs/_scripts/move_404.py _build/404.html - - - name: Remove .doctrees - run: rm -r _build/.doctrees - - - name: Upload to GitHub Pages - uses: JamesIves/github-pages-deploy-action@v4 - with: - folder: _build - target-folder: ${{ inputs.version }} - clean: false - - - name: Upload to GitHub Pages - uses: JamesIves/github-pages-deploy-action@v4 - if: inputs.latest - with: - folder: _build - clean-exclude: | - *.*.*/ - main +name: Manual Docs Versioning + +on: + workflow_dispatch: + inputs: + version: + description: 'Documentation version to create' + required: true + commit: + description: 'Commit used to build the Documentation version' + required: false + latest: + description: 'Latest version' + type: boolean + +permissions: + contents: write + +jobs: + docs: + name: Generate Website for new version + runs-on: ubuntu-latest + env: + SPHINX_GITHUB_CHANGELOG_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - uses: actions/checkout@v3 + if: inputs.commit == '' + + - uses: actions/checkout@v3 + if: inputs.commit != '' + with: + ref: ${{ inputs.commit }} + + - uses: actions/setup-python@v4 + with: + python-version: '3.9' + + - name: Install dependencies + run: pip install -r docs/requirements.txt + + - name: Install Gymnasium + run: pip install .[box2d,mujoco,jax] torch + + - name: Build Envs Docs + run: python docs/_scripts/gen_mds.py && python docs/_scripts/gen_envs_display.py + + - name: Build + run: sphinx-build -b dirhtml -v docs _build + + - name: Move 404 + run: mv _build/404/index.html _build/404.html + + - name: Update 404 
links + run: python docs/_scripts/move_404.py _build/404.html + + - name: Remove .doctrees + run: rm -r _build/.doctrees + + - name: Upload to GitHub Pages + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: _build + target-folder: ${{ inputs.version }} + clean: false + + - name: Upload to GitHub Pages + uses: JamesIves/github-pages-deploy-action@v4 + if: inputs.latest + with: + folder: _build + clean-exclude: | + *.*.*/ + main diff --git a/.github/workflows/docs-versioning.yml b/.github/workflows/docs-versioning.yml index 2a5e2e2a8..5e6b8f146 100644 --- a/.github/workflows/docs-versioning.yml +++ b/.github/workflows/docs-versioning.yml @@ -1,56 +1,59 @@ -name: Docs Versioning -on: - push: - tags: - - 'v?*.*.*' -permissions: - contents: write -jobs: - docs: - name: Generate Website for new version - runs-on: ubuntu-latest - env: - SPHINX_GITHUB_CHANGELOG_TOKEN: ${{ secrets.GITHUB_TOKEN }} - steps: - - uses: actions/checkout@v3 - - - uses: actions/setup-python@v4 - with: - python-version: '3.9' - - - name: Install dependencies - run: pip install -r docs/requirements.txt - - - name: Install Gymnasium - run: pip install .[box2d,mujoco,jax] torch - - - name: Build Envs Docs - run: python docs/_scripts/gen_mds.py && python docs/_scripts/gen_envs_display.py - - - name: Build - run: sphinx-build -b dirhtml -v docs _build - - - name: Move 404 - run: mv _build/404/index.html _build/404.html - - - name: Update 404 links - run: python docs/_scripts/move_404.py _build/404.html - - - name: Remove .doctrees - run: rm -r _build/.doctrees - - - name: Upload to GitHub Pages - uses: JamesIves/github-pages-deploy-action@v4 - with: - folder: _build - target-folder: ${{github.ref_name}} - clean: false - - - name: Upload to GitHub Pages - uses: JamesIves/github-pages-deploy-action@v4 - if: ${{ !contains(github.ref_name, 'a') }} - with: - folder: _build - clean-exclude: | - *.*.*/ - main +name: Docs Versioning + +on: + push: + tags: + - 'v?*.*.*' + +permissions: + contents: write + +jobs: + docs: + name: Generate Website for new version + runs-on: ubuntu-latest + env: + SPHINX_GITHUB_CHANGELOG_TOKEN: ${{ secrets.GITHUB_TOKEN }} + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v4 + with: + python-version: '3.9' + + - name: Install dependencies + run: pip install -r docs/requirements.txt + + - name: Install Gymnasium + run: pip install .[box2d,mujoco,jax] torch + + - name: Build Envs Docs + run: python docs/_scripts/gen_mds.py && python docs/_scripts/gen_envs_display.py + + - name: Build + run: sphinx-build -b dirhtml -v docs _build + + - name: Move 404 + run: mv _build/404/index.html _build/404.html + + - name: Update 404 links + run: python docs/_scripts/move_404.py _build/404.html + + - name: Remove .doctrees + run: rm -r _build/.doctrees + + - name: Upload to GitHub Pages + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: _build + target-folder: ${{github.ref_name}} + clean: false + + - name: Upload to GitHub Pages + uses: JamesIves/github-pages-deploy-action@v4 + if: ${{ !contains(github.ref_name, 'a') }} + with: + folder: _build + clean-exclude: | + *.*.*/ + main diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9dd1847e6..817a82a39 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,7 @@ repos: rev: 6.3.0 hooks: - id: pydocstyle - exclude: 
^(gymnasium/envs/box2d)|(gymnasium/envs/classic_control)|(gymnasium/envs/mujoco)|(gymnasium/envs/toy_text)|(tests/envs)|(tests/spaces)|(tests/utils)|(tests/vector)|(tests/wrappers)|(docs/) + exclude: ^(gymnasium/envs/box2d)|(gymnasium/envs/classic_control)|(gymnasium/envs/mujoco)|(gymnasium/envs/toy_text)|(tests/envs)|(tests/spaces)|(tests/utils)|(tests/vector)|(docs/) args: - --source - --explain diff --git a/docs/_scripts/gen_wrapper_table.py b/docs/_scripts/gen_wrapper_table.py new file mode 100644 index 000000000..4aad5d4c7 --- /dev/null +++ b/docs/_scripts/gen_wrapper_table.py @@ -0,0 +1,73 @@ +import os.path + +import gymnasium as gym + + +exclude_wrappers = {"vector"} + + +def generate_wrappers(): + wrapper_table = "" + for wrapper_name in sorted(gym.wrappers.__all__): + if wrapper_name not in exclude_wrappers: + wrapper_doc = getattr(gym.wrappers, wrapper_name).__doc__.split("\n")[0] + wrapper_table += f""" * - :class:`{wrapper_name}` + - {wrapper_doc} +""" + return wrapper_table + + +def generate_vector_wrappers(): + unique_vector_wrappers = set(gym.wrappers.vector.__all__) - set( + gym.wrappers.__all__ + ) + + vector_table = "" + for vector_name in sorted(unique_vector_wrappers): + vector_doc = getattr(gym.wrappers.vector, vector_name).__doc__.split("\n")[0] + vector_table += f""" * - :class:`{vector_name}` + - {vector_doc} +""" + return vector_table + + +if __name__ == "__main__": + gen_wrapper_table = generate_wrappers() + gen_vector_table = generate_vector_wrappers() + + page = f""" +# List of Gymnasium Wrappers + +Gymnasium provides a number of commonly used wrappers listed below. More information can be found on the particular +wrapper in the page on the wrapper type + +```{{eval-rst}} +.. py:currentmodule:: gymnasium.wrappers + +.. list-table:: + :header-rows: 1 + + * - Name + - Description +{gen_wrapper_table} +``` + +## Vector only Wrappers + +```{{eval-rst}} +.. py:currentmodule:: gymnasium.wrappers.vector + +.. list-table:: + :header-rows: 1 + + * - Name + - Description +{gen_vector_table} +``` +""" + + filename = os.path.join( + os.path.dirname(__file__), "..", "api", "wrappers", "table.md" + ) + with open(filename, "w") as file: + file.write(page) diff --git a/docs/api/env.md b/docs/api/env.md index cdef934ea..af3345b7e 100644 --- a/docs/api/env.md +++ b/docs/api/env.md @@ -1,29 +1,26 @@ --- -title: Utils +title: Env --- # Env -## gymnasium.Env - ```{eval-rst} .. autoclass:: gymnasium.Env ``` -### Methods - +## Methods ```{eval-rst} -.. autofunction:: gymnasium.Env.step -.. autofunction:: gymnasium.Env.reset -.. autofunction:: gymnasium.Env.render +.. automethod:: gymnasium.Env.step +.. automethod:: gymnasium.Env.reset +.. automethod:: gymnasium.Env.render +.. automethod:: gymnasium.Env.close ``` -### Attributes - +## Attributes ```{eval-rst} .. autoattribute:: gymnasium.Env.action_space - The Space object corresponding to valid actions, all valid actions should be contained with the space. For example, if the action space is of type `Discrete` and gives the value `Discrete(2)`, this means there are two valid discrete actions: 0 & 1. + The Space object corresponding to valid actions, all valid actions should be contained with the space. For example, if the action space is of type `Discrete` and gives the value `Discrete(2)`, this means there are two valid discrete actions: `0` & `1`. .. code:: @@ -51,29 +48,26 @@ title: Utils The render mode of the environment determined at initialisation -.. 
autoattribute:: gymnasium.Env.reward_range - - A tuple corresponding to the minimum and maximum possible rewards for an agent over an episode. The default reward range is set to :math:`(-\infty,+\infty)`. - .. autoattribute:: gymnasium.Env.spec - The ``EnvSpec`` of the environment normally set during :py:meth:`gymnasium.make` -``` + The :class:`EnvSpec` of the environment normally set during :py:meth:`gymnasium.make` -### Additional Methods - -```{eval-rst} -.. autofunction:: gymnasium.Env.close .. autoproperty:: gymnasium.Env.unwrapped .. autoproperty:: gymnasium.Env.np_random ``` -### Implementing environments +## Implementing environments ```{eval-rst} .. py:currentmodule:: gymnasium -When implementing an environment, the :meth:`Env.reset` and :meth:`Env.step` functions much be created describing the -dynamics of the environment. -For more information see the environment creation tutorial. +When implementing an environment, the :meth:`Env.reset` and :meth:`Env.step` functions must be created describing the dynamics of the environment. For more information see the environment creation tutorial. +``` + +## Creating environments + +```{eval-rst} +.. py:currentmodule:: gymnasium + +To create an environment, gymnasium provides :meth:`make` to initialise the environment along with several important wrappers. Furthermore, gymnasium provides :meth:`make_vec` for creating vector environments and to view all the environments that can be created, use :meth:`pprint_registry`. ``` diff --git a/docs/api/experimental.md b/docs/api/experimental.md deleted file mode 100644 index fb38a089d..000000000 --- a/docs/api/experimental.md +++ /dev/null @@ -1,157 +0,0 @@ ---- -title: Experimental ---- - -# Experimental - -```{toctree} -:hidden: -experimental/functional -experimental/wrappers -experimental/vector -experimental/vector_wrappers -experimental/vector_utils -``` - -## Functional Environments - -```{eval-rst} -The gymnasium ``Env`` provides high flexibility for the implementation of individual environments however this can complicate parallelism of environments. Therefore, we propose the :class:`gymnasium.experimental.FuncEnv` where each part of environment has its own function related to it. -``` - -## Wrappers - -Gymnasium already contains a large collection of wrappers, but we believe that the wrappers can be improved to - -* (Work in progress) Support arbitrarily complex observation / action spaces. As RL has advanced, action and observation spaces are becoming more complex and the current wrappers were not implemented with this mind. -* Support for Jax-based environments. With hardware accelerated environments, i.e. Brax, written in Jax and similar PyTorch based programs, NumPy is not the only game in town anymore. Therefore, these upgrades will use [Jumpy](https://github.com/farama-Foundation/jumpy), a project developed by Farama Foundation to provide automatic compatibility for NumPy, Jax and in the future PyTorch data for a large subset of the NumPy functions. -* More wrappers. Projects like [Supersuit](https://github.com/farama-Foundation/supersuit) aimed to bring more wrappers for RL, however, many users were not aware of the wrappers, so we plan to move the wrappers into Gymnasium. If we are missing common wrappers from the list provided above, please create an issue. -* Versioning. Like environments, the implementation details of wrappers can cause changes in agent performance. Therefore, we propose adding version numbers to all wrappers, i.e., `LambaActionV0`. 
We don't expect these version numbers to change regularly similar to environment version numbers and should ensure that all users know when significant changes could affect your agent's performance. Additionally, we hope that this will improve reproducibility of RL in the future, this is critical for academia. -* In v28, we aim to rewrite the VectorEnv to not inherit from Env, as a result new vectorized versions of the wrappers will be provided. - -We aimed to replace the wrappers in gymnasium v0.30.0 with these experimental wrappers. - -### Observation Wrappers -```{eval-rst} -.. py:currentmodule:: gymnasium - -.. list-table:: - :header-rows: 1 - - * - Old name - - New name - * - :class:`wrappers.TransformObservation` - - :class:`experimental.wrappers.LambdaObservationV0` - * - :class:`wrappers.FilterObservation` - - :class:`experimental.wrappers.FilterObservationV0` - * - :class:`wrappers.FlattenObservation` - - :class:`experimental.wrappers.FlattenObservationV0` - * - :class:`wrappers.GrayScaleObservation` - - :class:`experimental.wrappers.GrayscaleObservationV0` - * - :class:`wrappers.ResizeObservation` - - :class:`experimental.wrappers.ResizeObservationV0` - * - `supersuit.reshape_v0 `_ - - :class:`experimental.wrappers.ReshapeObservationV0` - * - Not Implemented - - :class:`experimental.wrappers.RescaleObservationV0` - * - `supersuit.dtype_v0 `_ - - :class:`experimental.wrappers.DtypeObservationV0` - * - :class:`wrappers.PixelObservationWrapper` - - :class:`experimental.wrappers.PixelObservationV0` - * - :class:`wrappers.NormalizeObservation` - - :class:`experimental.wrappers.NormalizeObservationV0` - * - :class:`wrappers.TimeAwareObservation` - - :class:`experimental.wrappers.TimeAwareObservationV0` - * - :class:`wrappers.FrameStack` - - :class:`experimental.wrappers.FrameStackObservationV0` - * - `supersuit.delay_observations_v0 `_ - - :class:`experimental.wrappers.DelayObservationV0` -``` - -### Action Wrappers -```{eval-rst} -.. py:currentmodule:: gymnasium - -.. list-table:: - :header-rows: 1 - - * - Old name - - New name - * - `supersuit.action_lambda_v1 `_ - - :class:`experimental.wrappers.LambdaActionV0` - * - :class:`wrappers.ClipAction` - - :class:`experimental.wrappers.ClipActionV0` - * - :class:`wrappers.RescaleAction` - - :class:`experimental.wrappers.RescaleActionV0` - * - `supersuit.sticky_actions_v0 `_ - - :class:`experimental.wrappers.StickyActionV0` -``` - -### Reward Wrappers -```{eval-rst} -.. py:currentmodule:: gymnasium - -.. list-table:: - :header-rows: 1 - - * - Old name - - New name - * - :class:`wrappers.TransformReward` - - :class:`experimental.wrappers.LambdaRewardV0` - * - `supersuit.clip_reward_v0 `_ - - :class:`experimental.wrappers.ClipRewardV0` - * - :class:`wrappers.NormalizeReward` - - :class:`experimental.wrappers.NormalizeRewardV1` -``` - -### Common Wrappers - -```{eval-rst} -.. py:currentmodule:: gymnasium - -.. 
list-table:: - :header-rows: 1 - - * - Old name - - New name - * - :class:`wrappers.AutoResetWrapper` - - :class:`experimental.wrappers.AutoresetV0` - * - :class:`wrappers.PassiveEnvChecker` - - :class:`experimental.wrappers.PassiveEnvCheckerV0` - * - :class:`wrappers.OrderEnforcing` - - :class:`experimental.wrappers.OrderEnforcingV0` - * - :class:`wrappers.EnvCompatibility` - - Moved to `shimmy `_ - * - :class:`wrappers.RecordEpisodeStatistics` - - :class:`experimental.wrappers.RecordEpisodeStatisticsV0` - * - :class:`wrappers.AtariPreprocessing` - - :class:`experimental.wrappers.AtariPreprocessingV0` -``` - -### Rendering Wrappers - -```{eval-rst} -.. py:currentmodule:: gymnasium - -.. list-table:: - :header-rows: 1 - - * - Old name - - New name - * - :class:`wrappers.RecordVideo` - - :class:`experimental.wrappers.RecordVideoV0` - * - :class:`wrappers.HumanRendering` - - :class:`experimental.wrappers.HumanRenderingV0` - * - :class:`wrappers.RenderCollection` - - :class:`experimental.wrappers.RenderCollectionV0` -``` - -### Environment data conversion - -```{eval-rst} -.. py:currentmodule:: gymnasium - -* :class:`experimental.wrappers.JaxToNumpyV0` -* :class:`experimental.wrappers.JaxToTorchV0` -* :class:`experimental.wrappers.NumpyToTorchV0` -``` diff --git a/docs/api/experimental/functional.md b/docs/api/experimental/functional.md index 451696aac..e69de29bb 100644 --- a/docs/api/experimental/functional.md +++ b/docs/api/experimental/functional.md @@ -1,37 +0,0 @@ ---- -title: Functional ---- - -# Functional Environment - -## gymnasium.experimental.FuncEnv - -```{eval-rst} -.. autoclass:: gymnasium.experimental.functional.FuncEnv - -.. autofunction:: gymnasium.experimental.functional.FuncEnv.initial -.. autofunction:: gymnasium.experimental.functional.FuncEnv.transition - -.. autofunction:: gymnasium.experimental.functional.FuncEnv.observation -.. autofunction:: gymnasium.experimental.functional.FuncEnv.reward -.. autofunction:: gymnasium.experimental.functional.FuncEnv.terminal - -.. autofunction:: gymnasium.experimental.functional.FuncEnv.state_info -.. autofunction:: gymnasium.experimental.functional.FuncEnv.step_info - -.. autofunction:: gymnasium.experimental.functional.FuncEnv.transform - -.. autofunction:: gymnasium.experimental.functional.FuncEnv.render_image -.. autofunction:: gymnasium.experimental.functional.FuncEnv.render_init -.. autofunction:: gymnasium.experimental.functional.FuncEnv.render_close -``` - -## gymnasium.experimental.func2env.FunctionalJaxCompatibilityEnv - -```{eval-rst} -.. autoclass:: gymnasium.experimental.functional_jax_env.FunctionalJaxEnv - -.. autofunction:: gymnasium.experimental.functional_jax_env.FunctionalJaxEnv.reset -.. autofunction:: gymnasium.experimental.functional_jax_env.FunctionalJaxEnv.step -.. autofunction:: gymnasium.experimental.functional_jax_env.FunctionalJaxEnv.render -``` diff --git a/docs/api/experimental/vector.md b/docs/api/experimental/vector.md deleted file mode 100644 index 07f9478d9..000000000 --- a/docs/api/experimental/vector.md +++ /dev/null @@ -1,39 +0,0 @@ ---- -title: Vector ---- - -# Vectorizing Environment - -## gymnasium.experimental.VectorEnv - -```{eval-rst} -.. autoclass:: gymnasium.experimental.vector.VectorEnv -.. autofunction:: gymnasium.experimental.vector.VectorEnv.reset -.. autofunction:: gymnasium.experimental.vector.VectorEnv.step -.. autofunction:: gymnasium.experimental.vector.VectorEnv.close -.. 
autofunction:: gymnasium.experimental.vector.VectorEnv.reset -``` - -## gymnasium.experimental.vector.AsyncVectorEnv - -```{eval-rst} -.. autoclass:: gymnasium.experimental.vector.AsyncVectorEnv -.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.reset -.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.step -.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.close -.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.call -.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.get_attr -.. autofunction:: gymnasium.experimental.vector.AsyncVectorEnv.set_attr -``` - -## gymnasium.experimental.vector.SyncVectorEnv - -```{eval-rst} -.. autoclass:: gymnasium.experimental.vector.SyncVectorEnv -.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.reset -.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.step -.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.close -.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.call -.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.get_attr -.. autofunction:: gymnasium.experimental.vector.SyncVectorEnv.set_attr -``` diff --git a/docs/api/experimental/vector_utils.md b/docs/api/experimental/vector_utils.md deleted file mode 100644 index 6879b9818..000000000 --- a/docs/api/experimental/vector_utils.md +++ /dev/null @@ -1,29 +0,0 @@ ---- -title: Vector Utility ---- - -# Utility functions for vectorisation - -## Spaces utility functions - -```{eval-rst} -.. autofunction:: gymnasium.experimental.vector.utils.batch_space -.. autofunction:: gymnasium.experimental.vector.utils.concatenate -.. autofunction:: gymnasium.experimental.vector.utils.iterate -.. autofunction:: gymnasium.experimental.vector.utils.create_empty_array -``` - -## Shared memory functions - -```{eval-rst} -.. autofunction:: gymnasium.experimental.vector.utils.create_shared_memory -.. autofunction:: gymnasium.experimental.vector.utils.read_from_shared_memory -.. autofunction:: gymnasium.experimental.vector.utils.write_to_shared_memory -``` - -## Miscellaneous - -```{eval-rst} -.. autofunction:: gymnasium.experimental.vector.utils.CloudpickleWrapper -.. autofunction:: gymnasium.experimental.vector.utils.clear_mpi_env_vars -``` diff --git a/docs/api/experimental/vector_wrappers.md b/docs/api/experimental/vector_wrappers.md deleted file mode 100644 index 68abc3c37..000000000 --- a/docs/api/experimental/vector_wrappers.md +++ /dev/null @@ -1,53 +0,0 @@ ---- -title: Vector Wrappers ---- - -# Vector Environment Wrappers - -```{eval-rst} -.. autoclass:: gymnasium.experimental.vector.VectorWrapper -``` - -## Vector Observation Wrappers - -```{eval-rst} -.. autoclass:: gymnasium.experimental.vector.VectorObservationWrapper -.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.FilterObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.FlattenObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.GrayscaleObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.ResizeObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.ReshapeObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.RescaleObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.DtypeObservationV0 -``` - -## Vector Action Wrappers - -```{eval-rst} -.. autoclass:: gymnasium.experimental.vector.VectorActionWrapper -.. 
autoclass:: gymnasium.experimental.wrappers.vector.LambdaActionV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.ClipActionV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.RescaleActionV0 -``` - -## Vector Reward Wrappers - -```{eval-rst} -.. autoclass:: gymnasium.experimental.vector.VectorRewardWrapper -.. autoclass:: gymnasium.experimental.wrappers.vector.LambdaRewardV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.ClipRewardV0 -``` - -## More Vector Wrappers - -```{eval-rst} -.. autoclass:: gymnasium.experimental.wrappers.vector.RecordEpisodeStatisticsV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.DictInfoToListV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaActionV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.VectorizeLambdaRewardV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.JaxToNumpyV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.JaxToTorchV0 -.. autoclass:: gymnasium.experimental.wrappers.vector.NumpyToTorchV0 -``` diff --git a/docs/api/experimental/wrappers.md b/docs/api/experimental/wrappers.md deleted file mode 100644 index 60644f71f..000000000 --- a/docs/api/experimental/wrappers.md +++ /dev/null @@ -1,65 +0,0 @@ ---- -title: Wrappers ---- - -# Wrappers - -## Observation Wrappers - -```{eval-rst} -.. autoclass:: gymnasium.experimental.wrappers.LambdaObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.FilterObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.FlattenObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.GrayscaleObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.ResizeObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.ReshapeObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.RescaleObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.DtypeObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.PixelObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.NormalizeObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.TimeAwareObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.FrameStackObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.DelayObservationV0 -.. autoclass:: gymnasium.experimental.wrappers.AtariPreprocessingV0 -``` - -## Action Wrappers - -```{eval-rst} -.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0 -.. autoclass:: gymnasium.experimental.wrappers.ClipActionV0 -.. autoclass:: gymnasium.experimental.wrappers.RescaleActionV0 -.. autoclass:: gymnasium.experimental.wrappers.StickyActionV0 -``` - -## Reward Wrappers - -```{eval-rst} -.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0 -.. autoclass:: gymnasium.experimental.wrappers.ClipRewardV0 -.. autoclass:: gymnasium.experimental.wrappers.NormalizeRewardV1 -``` - -## Other Wrappers - -```{eval-rst} -.. autoclass:: gymnasium.experimental.wrappers.AutoresetV0 -.. autoclass:: gymnasium.experimental.wrappers.PassiveEnvCheckerV0 -.. autoclass:: gymnasium.experimental.wrappers.OrderEnforcingV0 -.. autoclass:: gymnasium.experimental.wrappers.RecordEpisodeStatisticsV0 -``` - -## Rendering Wrappers - -```{eval-rst} -.. autoclass:: gymnasium.experimental.wrappers.RecordVideoV0 -.. autoclass:: gymnasium.experimental.wrappers.HumanRenderingV0 -.. autoclass:: gymnasium.experimental.wrappers.RenderCollectionV0 -``` - -## Environment data conversion -```{eval-rst} -.. 
autoclass:: gymnasium.experimental.wrappers.JaxToNumpyV0 -.. autoclass:: gymnasium.experimental.wrappers.JaxToTorchV0 -.. autoclass:: gymnasium.experimental.wrappers.NumpyToTorchV0 -``` diff --git a/docs/api/functional.md b/docs/api/functional.md new file mode 100644 index 000000000..5af602560 --- /dev/null +++ b/docs/api/functional.md @@ -0,0 +1,34 @@ +--- +title: Functional +--- + +# Functional Env + +```{eval-rst} +.. autoclass:: gymnasium.functional.FuncEnv + + .. automethod:: gymnasium.functional.FuncEnv.transform + + .. automethod:: gymnasium.functional.FuncEnv.initial + .. automethod:: gymnasium.functional.FuncEnv.initial_info + + .. automethod:: gymnasium.functional.FuncEnv.transition + .. automethod:: gymnasium.functional.FuncEnv.observation + .. automethod:: gymnasium.functional.FuncEnv.reward + .. automethod:: gymnasium.functional.FuncEnv.terminal + .. automethod:: gymnasium.functional.FuncEnv.transition_info + + .. automethod:: gymnasium.functional.FuncEnv.render_image + .. automethod:: gymnasium.functional.FuncEnv.render_initialise + .. automethod:: gymnasium.functional.FuncEnv.render_close +``` + +## Converting Jax-based Functional environments to standard Env + +```{eval-rst} +.. autoclass:: gymnasium.utils.functional_jax_env.FunctionalJaxEnv + + .. automethod:: gymnasium.utils.functional_jax_env.FunctionalJaxEnv.reset + .. automethod:: gymnasium.utils.functional_jax_env.FunctionalJaxEnv.step + .. automethod:: gymnasium.utils.functional_jax_env.FunctionalJaxEnv.render +``` diff --git a/docs/api/registry.md b/docs/api/registry.md index 422472aae..082794295 100644 --- a/docs/api/registry.md +++ b/docs/api/registry.md @@ -2,14 +2,13 @@ title: Registry --- -# Register and Make +# Make and register ```{eval-rst} Gymnasium allows users to automatically load environments, pre-wrapped with several important wrappers through the :meth:`gymnasium.make` function. To do this, the environment must be registered prior with :meth:`gymnasium.register`. To get the environment specifications for a registered environment, use :meth:`gymnasium.spec` and to print the whole registry, use :meth:`gymnasium.pprint_registry`. -``` -```{eval-rst} .. autofunction:: gymnasium.make +.. autofunction:: gymnasium.make_vec .. autofunction:: gymnasium.register .. autofunction:: gymnasium.spec .. autofunction:: gymnasium.pprint_registry @@ -19,6 +18,7 @@ Gymnasium allows users to automatically load environments, pre-wrapped with seve ```{eval-rst} .. autoclass:: gymnasium.envs.registration.EnvSpec +.. autoclass:: gymnasium.envs.registration.WrapperSpec .. attribute:: gymnasium.envs.registration.registry The Global registry for gymnasium which is where environment specifications are stored by :meth:`gymnasium.register` and from which :meth:`gymnasium.make` is used to create environments. @@ -36,5 +36,4 @@ Gymnasium allows users to automatically load environments, pre-wrapped with seve .. autofunction:: gymnasium.envs.registration.find_highest_version .. autofunction:: gymnasium.envs.registration.namespace .. autofunction:: gymnasium.envs.registration.load_env_creator -.. autofunction:: gymnasium.envs.registration.load_plugin_envs ``` diff --git a/docs/api/spaces.md b/docs/api/spaces.md index 612f0cee6..791cf43ac 100644 --- a/docs/api/spaces.md +++ b/docs/api/spaces.md @@ -9,39 +9,38 @@ title: Spaces spaces/fundamental spaces/composite spaces/utils -spaces/vector_utils +vector/utils ``` ```{eval-rst} .. automodule:: gymnasium.spaces -``` -## The Base Class - -```{eval-rst} .. 
autoclass:: gymnasium.spaces.Space ``` -### Attributes - +## Attributes ```{eval-rst} -.. autoproperty:: gymnasium.spaces.space.Space.shape +.. py:currentmodule:: gymnasium.spaces + +.. autoproperty:: Space.shape .. property:: Space.dtype Return the data type of this space. -.. autoproperty:: gymnasium.spaces.space.Space.is_np_flattenable +.. autoproperty:: Space.is_np_flattenable +.. autoproperty:: Space.np_random ``` -### Methods - +## Methods Each space implements the following functions: ```{eval-rst} -.. autofunction:: gymnasium.spaces.space.Space.sample -.. autofunction:: gymnasium.spaces.space.Space.contains -.. autofunction:: gymnasium.spaces.space.Space.seed -.. autofunction:: gymnasium.spaces.space.Space.to_jsonable -.. autofunction:: gymnasium.spaces.space.Space.from_jsonable +.. py:currentmodule:: gymnasium.spaces + +.. automethod:: Space.sample +.. automethod:: Space.contains +.. automethod:: Space.seed +.. automethod:: Space.to_jsonable +.. automethod:: Space.from_jsonable ``` ## Fundamental Spaces @@ -49,13 +48,13 @@ Each space implements the following functions: Gymnasium has a number of fundamental spaces that are used as building boxes for more complex spaces. ```{eval-rst} -.. currentmodule:: gymnasium.spaces +.. py:currentmodule:: gymnasium.spaces -* :py:class:`Box` - Supports continuous (and discrete) vectors or matrices, used for vector observations, images, etc -* :py:class:`Discrete` - Supports a single discrete number of values with an optional start for the values -* :py:class:`MultiBinary` - Supports single or matrices of binary values, used for holding down a button or if an agent has an object -* :py:class:`MultiDiscrete` - Supports multiple discrete values with multiple axes, used for controller actions -* :py:class:`Text` - Supports strings, used for passing agent messages, mission details, etc +* :class:`Box` - Supports continuous (and discrete) vectors or matrices, used for vector observations, images, etc +* :class:`Discrete` - Supports a single discrete number of values with an optional start for the values +* :class:`MultiBinary` - Supports single or matrices of binary values, used for holding down a button or if an agent has an object +* :class:`MultiDiscrete` - Supports multiple discrete values with multiple axes, used for controller actions +* :class:`Text` - Supports strings, used for passing agent messages, mission details, etc ``` ## Composite Spaces @@ -63,37 +62,41 @@ Gymnasium has a number of fundamental spaces that are used as building boxes for Often environment spaces require joining fundamental spaces together for vectorised environments, separate agents or readability of the space. ```{eval-rst} -* :py:class:`Dict` - Supports a dictionary of keys and subspaces, used for a fixed number of unordered spaces -* :py:class:`Tuple` - Supports a tuple of subspaces, used for multiple for a fixed number of ordered spaces -* :py:class:`Sequence` - Supports a variable number of instances of a single subspace, used for entities spaces or selecting a variable number of actions +.. 
## Composite Spaces

Often environment spaces require joining fundamental spaces together for vectorised environments, separate agents or readability of the space.

```{eval-rst}
-* :py:class:`Dict` - Supports a dictionary of keys and subspaces, used for a fixed number of unordered spaces
-* :py:class:`Tuple` - Supports a tuple of subspaces, used for multiple for a fixed number of ordered spaces
-* :py:class:`Sequence` - Supports a variable number of instances of a single subspace, used for entities spaces or selecting a variable number of actions
+.. py:currentmodule:: gymnasium.spaces
+
+* :class:`Dict` - Supports a dictionary of keys and subspaces, used for a fixed number of unordered spaces
+* :class:`Tuple` - Supports a tuple of subspaces, used for a fixed number of ordered spaces
+* :class:`Sequence` - Supports a variable number of instances of a single subspace, used for entity spaces or selecting a variable number of actions
* :py:class:`Graph` - Supports graph based actions or observations with discrete or continuous nodes and edge values.
```

-## Utils
+## Utility functions

Gymnasium contains a number of helpful utility functions for flattening and unflattening spaces. This can be important for passing information to neural networks.

```{eval-rst}
-* :py:class:`utils.flatdim` - The number of dimensions the flattened space will contain
-* :py:class:`utils.flatten_space` - Flattens a space for which the `flattened` space instances will contain
-* :py:class:`utils.flatten` - Flattens an instance of a space that is contained within the flattened version of the space
-* :py:class:`utils.unflatten` - The reverse of the `flatten_space` function
+.. py:currentmodule:: gymnasium.spaces
+
+* :class:`utils.flatdim` - The number of dimensions the flattened space will contain
+* :class:`utils.flatten_space` - Flattens a space, producing the space that ``flattened`` samples of the original space are contained within
+* :class:`utils.flatten` - Flattens an instance of a space, such that it is contained within the flattened version of the space
+* :class:`utils.unflatten` - The reverse of the :class:`utils.flatten` function
```

-## Vector Utils
+## Vector Utility functions

When vectorizing environments, it is necessary to modify the observation and action spaces for the new batched space sizes. Therefore, Gymnasium provides a number of additional functions for using a space with a vectorized environment.

```{eval-rst}
-.. currentmodule:: gymnasium
+.. py:currentmodule:: gymnasium

-* :py:class:`vector.utils.batch_space`
-* :py:class:`vector.utils.concatenate`
-* :py:class:`vector.utils.iterate`
-* :py:class:`vector.utils.create_empty_array`
-* :py:class:`vector.utils.create_shared_memory`
-* :py:class:`vector.utils.read_from_shared_memory`
-* :py:class:`vector.utils.write_to_shared_memory`
+* :class:`vector.utils.batch_space` - Transforms a space into the equivalent batched space for ``n`` sub-environments
+* :class:`vector.utils.concatenate` - Concatenates a space's samples into a pre-generated array
+* :class:`vector.utils.iterate` - Iterates over the batched space's samples
+* :class:`vector.utils.create_empty_array` - Creates an empty sample for a space (generally used with ``concatenate``)
+* :class:`vector.utils.create_shared_memory` - Creates shared memory for asynchronous (multiprocessing) environments
+* :class:`vector.utils.read_from_shared_memory` - Reads from shared memory for asynchronous (multiprocessing) environments
+* :class:`vector.utils.write_to_shared_memory` - Writes to shared memory for asynchronous (multiprocessing) environments
```
diff --git a/docs/api/spaces/composite.md b/docs/api/spaces/composite.md
index 5b99704fa..b43a5b0f5 100644
--- a/docs/api/spaces/composite.md
+++ b/docs/api/spaces/composite.md
@@ -1,37 +1,24 @@
# Composite Spaces

-## Dict

```{eval-rst}
.. autoclass:: gymnasium.spaces.Dict

-.. automethod:: gymnasium.spaces.Dict.sample
-.. automethod:: gymnasium.spaces.Dict.seed
-```
+    .. automethod:: gymnasium.spaces.Dict.sample
+    .. automethod:: gymnasium.spaces.Dict.seed

-## Tuple
-
-```{eval-rst}
.. autoclass:: gymnasium.spaces.Tuple
-.. 
automethod:: gymnasium.spaces.Tuple.sample -.. automethod:: gymnasium.spaces.Tuple.seed -``` + .. automethod:: gymnasium.spaces.Tuple.sample + .. automethod:: gymnasium.spaces.Tuple.seed -## Sequence - -```{eval-rst} .. autoclass:: gymnasium.spaces.Sequence -.. automethod:: gymnasium.spaces.Sequence.sample -.. automethod:: gymnasium.spaces.Sequence.seed -``` + .. automethod:: gymnasium.spaces.Sequence.sample + .. automethod:: gymnasium.spaces.Sequence.seed -## Graph - -```{eval-rst} .. autoclass:: gymnasium.spaces.Graph -.. automethod:: gymnasium.spaces.Graph.sample -.. automethod:: gymnasium.spaces.Graph.seed + .. automethod:: gymnasium.spaces.Graph.sample + .. automethod:: gymnasium.spaces.Graph.seed ``` diff --git a/docs/api/spaces/fundamental.md b/docs/api/spaces/fundamental.md index 190f73790..a726af32f 100644 --- a/docs/api/spaces/fundamental.md +++ b/docs/api/spaces/fundamental.md @@ -4,46 +4,30 @@ title: Fundamental Spaces # Fundamental Spaces -## Box - ```{eval-rst} .. autoclass:: gymnasium.spaces.Box -.. automethod:: gymnasium.spaces.Box.sample -.. automethod:: gymnasium.spaces.Box.seed -.. automethod:: gymnasium.spaces.Box.is_bounded -``` + .. automethod:: gymnasium.spaces.Box.sample + .. automethod:: gymnasium.spaces.Box.seed + .. automethod:: gymnasium.spaces.Box.is_bounded -## Discrete - -```{eval-rst} .. autoclass:: gymnasium.spaces.Discrete -.. automethod:: gymnasium.spaces.Discrete.sample -.. automethod:: gymnasium.spaces.Discrete.seed -``` -## MultiBinary + .. automethod:: gymnasium.spaces.Discrete.sample + .. automethod:: gymnasium.spaces.Discrete.seed -```{eval-rst} .. autoclass:: gymnasium.spaces.MultiBinary -.. automethod:: gymnasium.spaces.MultiBinary.sample -.. automethod:: gymnasium.spaces.MultiBinary.seed -``` -## MultiDiscrete + .. automethod:: gymnasium.spaces.MultiBinary.sample + .. automethod:: gymnasium.spaces.MultiBinary.seed -```{eval-rst} .. autoclass:: gymnasium.spaces.MultiDiscrete -.. automethod:: gymnasium.spaces.MultiDiscrete.sample -.. automethod:: gymnasium.spaces.MultiDiscrete.seed -``` + .. automethod:: gymnasium.spaces.MultiDiscrete.sample + .. automethod:: gymnasium.spaces.MultiDiscrete.seed -## Text - -```{eval-rst} .. autoclass:: gymnasium.spaces.Text -.. automethod:: gymnasium.spaces.Text.sample -.. automethod:: gymnasium.spaces.Text.seed + .. automethod:: gymnasium.spaces.Text.sample + .. automethod:: gymnasium.spaces.Text.seed ``` diff --git a/docs/api/utils.md b/docs/api/utils.md index 29f1401ce..124b7fc33 100644 --- a/docs/api/utils.md +++ b/docs/api/utils.md @@ -1,8 +1,20 @@ --- -title: Utils +title: Utility functions --- -# Utils +# Utility functions + +## Seeding + +```{eval-rst} +.. autofunction:: gymnasium.utils.seeding.np_random +``` + +## Environment Checking + +```{eval-rst} +.. autofunction:: gymnasium.utils.env_checker.check_env +``` ## Visualization @@ -17,6 +29,12 @@ title: Utils .. automethod:: process_event ``` +## Environment pickling + +```{eval-rst} +.. autoclass:: gymnasium.utils.ezpickle.EzPickle +``` + ## Save Rendering Videos ```{eval-rst} @@ -31,15 +49,3 @@ title: Utils .. autofunction:: gymnasium.utils.step_api_compatibility.convert_to_terminated_truncated_step_api .. autofunction:: gymnasium.utils.step_api_compatibility.convert_to_done_step_api ``` - -## Seeding - -```{eval-rst} -.. autofunction:: gymnasium.utils.seeding.np_random -``` - -## Environment Checking - -```{eval-rst} -.. 
autofunction:: gymnasium.utils.env_checker.check_env
-```
diff --git a/docs/api/vector.md b/docs/api/vector.md
index 50dedbe5c..e0db3e8ae 100644
--- a/docs/api/vector.md
+++ b/docs/api/vector.md
@@ -2,7 +2,15 @@
title: Vector
---

-# Vector
+# Vector environments
+
+```{toctree}
+:hidden:
+vector/wrappers
+vector/async_vector_env
+vector/sync_vector_env
+vector/utils
+```

## Gymnasium.vector.VectorEnv

@@ -14,62 +22,47 @@ title: Vector

```{eval-rst}
.. automethod:: gymnasium.vector.VectorEnv.reset
-
.. automethod:: gymnasium.vector.VectorEnv.step
-
.. automethod:: gymnasium.vector.VectorEnv.close
```

### Attributes

```{eval-rst}
-.. attribute:: action_space
+.. autoattribute:: gymnasium.vector.VectorEnv.num_envs

-    The (batched) action space. The input actions of `step` must be valid elements of `action_space`.::
+    The number of sub-environments in the vector environment.

-    >>> envs = gymnasium.vector.make("CartPole-v1", num_envs=3)
-    >>> envs.action_space
-    MultiDiscrete([2 2 2])
+.. autoattribute:: gymnasium.vector.VectorEnv.action_space

-.. attribute:: observation_space
+    The (batched) action space. The input actions of ``step`` must be valid elements of ``action_space``.

-    The (batched) observation space. The observations returned by `reset` and `step` are valid elements of `observation_space`.::
+
+.. autoattribute:: gymnasium.vector.VectorEnv.observation_space

-    >>> envs = gymnasium.vector.make("CartPole-v1", num_envs=3)
-    >>> envs.observation_space
-    Box([[-4.8 ...]], [[4.8 ...]], (3, 4), float32)
+    The (batched) observation space. The observations returned by ``reset`` and ``step`` are valid elements of ``observation_space``.

-.. attribute:: single_action_space
+
+.. autoattribute:: gymnasium.vector.VectorEnv.single_action_space

-    The action space of an environment copy.::
+    The action space of a sub-environment.

-    >>> envs = gymnasium.vector.make("CartPole-v1", num_envs=3)
-    >>> envs.single_action_space
-    Discrete(2)
+.. autoattribute:: gymnasium.vector.VectorEnv.single_observation_space

-.. attribute:: single_observation_space
+    The observation space of a sub-environment.

-    The observation space of an environment copy.::
+.. autoattribute:: gymnasium.vector.VectorEnv.spec

-    >>> envs = gymnasium.vector.make("CartPole-v1", num_envs=3)
-    >>> envs.single_observation_space
-    Box([-4.8 ...], [4.8 ...], (4,), float32)
+    The ``EnvSpec`` of the environment, normally set during :py:meth:`gymnasium.make_vec`.
```

+### Additional Methods
+
+```{eval-rst}
+.. autoproperty:: gymnasium.vector.VectorEnv.unwrapped
+.. autoproperty:: gymnasium.vector.VectorEnv.np_random
+```
+
## Making Vector Environments

```{eval-rst}
-.. autofunction:: gymnasium.vector.make
-```
-
-## Async Vector Env
-
-```{eval-rst}
-.. autoclass:: gymnasium.vector.AsyncVectorEnv
-```
-
-## Sync Vector Env
-
-```{eval-rst}
-.. autoclass:: gymnasium.vector.SyncVectorEnv
+To create vector environments, gymnasium provides :func:`gymnasium.make_vec` as an equivalent function to :func:`gymnasium.make`.
```
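+
+As a minimal sketch of this (assuming the ``num_envs`` and ``vectorization_mode`` keyword arguments):
+
+```python
+import gymnasium as gym
+
+# Three CartPole sub-environments stepped in lockstep within the current process
+envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
+
+observations, infos = envs.reset(seed=42)
+actions = envs.action_space.sample()  # batched actions, one per sub-environment
+observations, rewards, terminations, truncations, infos = envs.step(actions)
+envs.close()
+```
diff --git a/docs/api/vector/async_vector_env.md b/docs/api/vector/async_vector_env.md
new file mode 100644
index 000000000..a0368419e
--- /dev/null
+++ b/docs/api/vector/async_vector_env.md
@@ -0,0 +1,13 @@
+# AsyncVectorEnv
+
+```{eval-rst}
+.. autoclass:: gymnasium.vector.AsyncVectorEnv
+
+    .. automethod:: gymnasium.vector.AsyncVectorEnv.reset
+    .. automethod:: gymnasium.vector.AsyncVectorEnv.step
+    .. automethod:: gymnasium.vector.AsyncVectorEnv.close
+
+    .. automethod:: gymnasium.vector.AsyncVectorEnv.call
+    .. 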
automethod:: gymnasium.vector.AsyncVectorEnv.get_attr + .. automethod:: gymnasium.vector.AsyncVectorEnv.set_attr +``` diff --git a/docs/api/vector/sync_vector_env.md b/docs/api/vector/sync_vector_env.md new file mode 100644 index 000000000..3855e4820 --- /dev/null +++ b/docs/api/vector/sync_vector_env.md @@ -0,0 +1,13 @@ +# SyncVectorEnv + +```{eval-rst} +.. autoclass:: gymnasium.vector.SyncVectorEnv + + .. automethod:: gymnasium.vector.SyncVectorEnv.reset + .. automethod:: gymnasium.vector.SyncVectorEnv.step + .. automethod:: gymnasium.vector.SyncVectorEnv.close + + .. automethod:: gymnasium.vector.SyncVectorEnv.call + .. automethod:: gymnasium.vector.SyncVectorEnv.get_attr + .. automethod:: gymnasium.vector.SyncVectorEnv.set_attr +``` diff --git a/docs/api/spaces/vector_utils.md b/docs/api/vector/utils.md similarity index 66% rename from docs/api/spaces/vector_utils.md rename to docs/api/vector/utils.md index ed1e8b43a..481de7551 100644 --- a/docs/api/spaces/vector_utils.md +++ b/docs/api/vector/utils.md @@ -1,20 +1,25 @@ ---- -title: Vector Utils ---- +# Utility functions -# Spaces Vector Utils +## Vectorizing Spaces ```{eval-rst} .. autofunction:: gymnasium.vector.utils.batch_space .. autofunction:: gymnasium.vector.utils.concatenate .. autofunction:: gymnasium.vector.utils.iterate +.. autofunction:: gymnasium.vector.utils.create_empty_array ``` -## Shared Memory Utils +## Shared Memory for a Space ```{eval-rst} -.. autofunction:: gymnasium.vector.utils.create_empty_array .. autofunction:: gymnasium.vector.utils.create_shared_memory .. autofunction:: gymnasium.vector.utils.read_from_shared_memory .. autofunction:: gymnasium.vector.utils.write_to_shared_memory ``` + +## Miscellaneous + +```{eval-rst} +.. autofunction:: gymnasium.vector.utils.CloudpickleWrapper +.. autofunction:: gymnasium.vector.utils.clear_mpi_env_vars +``` diff --git a/docs/api/vector/wrappers.md b/docs/api/vector/wrappers.md new file mode 100644 index 000000000..2e3272c74 --- /dev/null +++ b/docs/api/vector/wrappers.md @@ -0,0 +1,26 @@ +--- +title: Vector Wrappers +--- + +# Vector Wrappers + +```{eval-rst} +.. autoclass:: gymnasium.vector.VectorWrapper + + .. automethod:: gymnasium.vector.VectorWrapper.step + .. automethod:: gymnasium.vector.VectorWrapper.reset + .. automethod:: gymnasium.vector.VectorWrapper.close + +.. autoclass:: gymnasium.vector.VectorObservationWrapper + + .. automethod:: gymnasium.vector.VectorObservationWrapper.vector_observation + .. automethod:: gymnasium.vector.VectorObservationWrapper.single_observation + +.. autoclass:: gymnasium.vector.VectorActionWrapper + + .. automethod:: gymnasium.vector.VectorActionWrapper.actions + +.. autoclass:: gymnasium.vector.VectorRewardWrapper + + .. automethod:: gymnasium.vector.VectorRewardWrapper.rewards +``` diff --git a/docs/api/wrappers.md b/docs/api/wrappers.md index be4be4a3f..000faf061 100644 --- a/docs/api/wrappers.md +++ b/docs/api/wrappers.md @@ -6,134 +6,47 @@ title: Wrapper ```{toctree} :hidden: + +wrappers/table wrappers/misc_wrappers wrappers/action_wrappers wrappers/observation_wrappers wrappers/reward_wrappers +wrappers/vector_wrappers ``` ```{eval-rst} .. automodule:: gymnasium.wrappers - ``` -## gymnasium.Wrapper ```{eval-rst} .. autoclass:: gymnasium.Wrapper ``` -### Methods - +## Methods ```{eval-rst} -.. autofunction:: gymnasium.Wrapper.step -.. autofunction:: gymnasium.Wrapper.reset -.. autofunction:: gymnasium.Wrapper.close +.. automethod:: gymnasium.Wrapper.step +.. automethod:: gymnasium.Wrapper.reset +.. 
automethod:: gymnasium.Wrapper.render
+.. automethod:: gymnasium.Wrapper.close
+.. automethod:: gymnasium.Wrapper.wrapper_spec
+.. automethod:: gymnasium.Wrapper.get_wrapper_attr
+.. automethod:: gymnasium.Wrapper.set_wrapper_attr
```

-### Attributes
-
+## Attributes
+
```{eval-rst}
-.. autoproperty:: gymnasium.Wrapper.action_space
-.. autoproperty:: gymnasium.Wrapper.observation_space
-.. autoproperty:: gymnasium.Wrapper.reward_range
-.. autoproperty:: gymnasium.Wrapper.spec
-.. autoproperty:: gymnasium.Wrapper.metadata
-.. autoproperty:: gymnasium.Wrapper.np_random
-.. attribute:: gymnasium.Wrapper.env
+.. autoattribute:: gymnasium.Wrapper.env

    The environment (one level underneath) that this wrapper wraps.
-    This may itself be a wrapped environment.
-    To obtain the environment underneath all layers of wrappers, use :attr:`gymnasium.Wrapper.unwrapped`.
+    This may itself be a wrapped environment. To obtain the environment underneath all layers of wrappers, use :attr:`gymnasium.Wrapper.unwrapped`.

+.. autoproperty:: gymnasium.Wrapper.action_space
+.. autoproperty:: gymnasium.Wrapper.observation_space
+.. autoproperty:: gymnasium.Wrapper.spec
+.. autoproperty:: gymnasium.Wrapper.metadata
+.. autoproperty:: gymnasium.Wrapper.np_random
.. autoproperty:: gymnasium.Wrapper.unwrapped
```
-
-## Gymnasium Wrappers
-
-Gymnasium provides a number of commonly used wrappers listed below. More information can be found on the particular
-wrapper in the page on the wrapper type
-
-```{eval-rst}
-.. py:currentmodule:: gymnasium.wrappers
-
-.. list-table::
-    :header-rows: 1
-
-    * - Name
-      - Type
-      - Description
-    * - :class:`AtariPreprocessing`
-      - Misc Wrapper
-      - Implements the common preprocessing applied to Atari environments
-    * - :class:`AutoResetWrapper`
-      - Misc Wrapper
-      - The wrapped environment will automatically reset when the terminated or truncated state is reached.
-    * - :class:`ClipAction`
-      - Action Wrapper
-      - Clip the continuous action to the valid bound specified by the environment's `action_space`
-    * - :class:`EnvCompatibility`
-      - Misc Wrapper
-      - Provides compatibility for environments written in the OpenAI Gym v0.21 API to look like Gymnasium environments
-    * - :class:`FilterObservation`
-      - Observation Wrapper
-      - Filters a dictionary observation spaces to only requested keys
-    * - :class:`FlattenObservation`
-      - Observation Wrapper
-      - An Observation wrapper that flattens the observation
-    * - :class:`FrameStack`
-      - Observation Wrapper
-      - AnObservation wrapper that stacks the observations in a rolling manner.
-    * - :class:`GrayScaleObservation`
-      - Observation Wrapper
-      - Convert the image observation from RGB to gray scale.
-    * - :class:`HumanRendering`
-      - Misc Wrapper
-      - Allows human like rendering for environments that support "rgb_array" rendering
-    * - :class:`NormalizeObservation`
-      - Observation Wrapper
-      - This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
-    * - :class:`NormalizeReward`
-      - Reward Wrapper
-      - This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance.
-    * - :class:`OrderEnforcing`
-      - Misc Wrapper
-      - This will produce an error if `step` or `render` is called before `reset`
-    * - :class:`PixelObservationWrapper`
-      - Observation Wrapper
-      - Augment observations by pixel values obtained via `render` that can be added to or replaces the environments observation. 
- * - :class:`RecordEpisodeStatistics` - - Misc Wrapper - - This will keep track of cumulative rewards and episode lengths returning them at the end. - * - :class:`RecordVideo` - - Misc Wrapper - - This wrapper will record videos of rollouts. - * - :class:`RenderCollection` - - Misc Wrapper - - Enable list versions of render modes, i.e. "rgb_array_list" for "rgb_array" such that the rendering for each step are saved in a list until `render` is called. - * - :class:`RescaleAction` - - Action Wrapper - - Rescales the continuous action space of the environment to a range \[`min_action`, `max_action`], where `min_action` and `max_action` are numpy arrays or floats. - * - :class:`ResizeObservation` - - Observation Wrapper - - This wrapper works on environments with image observations (or more generally observations of shape AxBxC) and resizes the observation to the shape given by the tuple `shape`. - * - :class:`StepAPICompatibility` - - Misc Wrapper - - Modifies an environment step function from (old) done to the (new) termination / truncation API. - * - :class:`TimeAwareObservation` - - Observation Wrapper - - Augment the observation with current time step in the trajectory (by appending it to the observation). - * - :class:`TimeLimit` - - Misc Wrapper - - This wrapper will emit a truncated signal if the specified number of steps is exceeded in an episode. - * - :class:`TransformObservation` - - Observation Wrapper - - This wrapper will apply function to observations - * - :class:`TransformReward` - - Reward Wrapper - - This wrapper will apply function to rewards - * - :class:`VectorListInfo` - - Misc Wrapper - - This wrapper will convert the info of a vectorized environment from the `dict` format to a `list` of dictionaries where the i-th dictionary contains info of the i-th environment. -``` diff --git a/docs/api/wrappers/action_wrappers.md b/docs/api/wrappers/action_wrappers.md index d93e8606a..87f777b1b 100644 --- a/docs/api/wrappers/action_wrappers.md +++ b/docs/api/wrappers/action_wrappers.md @@ -5,12 +5,14 @@ ```{eval-rst} .. autoclass:: gymnasium.ActionWrapper - .. automethod:: gymnasium.ActionWrapper.action + .. automethod:: gymnasium.ActionWrapper.action ``` ## Available Action Wrappers ```{eval-rst} +.. autoclass:: gymnasium.wrappers.TransformAction .. autoclass:: gymnasium.wrappers.ClipAction .. autoclass:: gymnasium.wrappers.RescaleAction +.. autoclass:: gymnasium.wrappers.StickyAction ``` diff --git a/docs/api/wrappers/misc_wrappers.md b/docs/api/wrappers/misc_wrappers.md index 3152d9d70..25deaafda 100644 --- a/docs/api/wrappers/misc_wrappers.md +++ b/docs/api/wrappers/misc_wrappers.md @@ -1,16 +1,33 @@ +--- +title: Misc Wrappers +--- + # Misc Wrappers + +## Common Wrappers + ```{eval-rst} +.. autoclass:: gymnasium.wrappers.TimeLimit +.. autoclass:: gymnasium.wrappers.RecordVideo +.. autoclass:: gymnasium.wrappers.RecordEpisodeStatistics .. autoclass:: gymnasium.wrappers.AtariPreprocessing -.. autoclass:: gymnasium.wrappers.AutoResetWrapper -.. autoclass:: gymnasium.wrappers.EnvCompatibility -.. autoclass:: gymnasium.wrappers.StepAPICompatibility +``` + +## Uncommon Wrappers + +```{eval-rst} +.. autoclass:: gymnasium.wrappers.Autoreset .. autoclass:: gymnasium.wrappers.PassiveEnvChecker .. autoclass:: gymnasium.wrappers.HumanRendering .. autoclass:: gymnasium.wrappers.OrderEnforcing -.. autoclass:: gymnasium.wrappers.RecordEpisodeStatistics -.. autoclass:: gymnasium.wrappers.RecordVideo .. autoclass:: gymnasium.wrappers.RenderCollection -.. 
autoclass:: gymnasium.wrappers.TimeLimit
-.. autoclass:: gymnasium.wrappers.VectorListInfo
+```
+
+## Data Conversion Wrappers
+
+```{eval-rst}
+.. autoclass:: gymnasium.wrappers.JaxToNumpy
+.. autoclass:: gymnasium.wrappers.JaxToTorch
+.. autoclass:: gymnasium.wrappers.NumpyToTorch
```
diff --git a/docs/api/wrappers/observation_wrappers.md b/docs/api/wrappers/observation_wrappers.md
index 14238fc68..10284aca0 100644
--- a/docs/api/wrappers/observation_wrappers.md
+++ b/docs/api/wrappers/observation_wrappers.md
@@ -1,23 +1,26 @@
# Observation Wrappers

-## Base Class
-
```{eval-rst}
.. autoclass:: gymnasium.ObservationWrapper

    .. automethod:: gymnasium.ObservationWrapper.observation
```

-## Available Observation Wrappers
+## Implemented Wrappers

```{eval-rst}
.. autoclass:: gymnasium.wrappers.TransformObservation
+.. autoclass:: gymnasium.wrappers.DelayObservation
+.. autoclass:: gymnasium.wrappers.DtypeObservation
.. autoclass:: gymnasium.wrappers.FilterObservation
.. autoclass:: gymnasium.wrappers.FlattenObservation
-.. autoclass:: gymnasium.wrappers.FrameStack
-.. autoclass:: gymnasium.wrappers.GrayScaleObservation
+.. autoclass:: gymnasium.wrappers.FrameStackObservation
+.. autoclass:: gymnasium.wrappers.GrayscaleObservation
+.. autoclass:: gymnasium.wrappers.MaxAndSkipObservation
.. autoclass:: gymnasium.wrappers.NormalizeObservation
-.. autoclass:: gymnasium.wrappers.PixelObservationWrapper
+.. autoclass:: gymnasium.wrappers.RenderObservation
.. autoclass:: gymnasium.wrappers.ResizeObservation
+.. autoclass:: gymnasium.wrappers.ReshapeObservation
+.. autoclass:: gymnasium.wrappers.RescaleObservation
.. autoclass:: gymnasium.wrappers.TimeAwareObservation
```
diff --git a/docs/api/wrappers/reward_wrappers.md b/docs/api/wrappers/reward_wrappers.md
index 45d0476dd..209d4b839 100644
--- a/docs/api/wrappers/reward_wrappers.md
+++ b/docs/api/wrappers/reward_wrappers.md
@@ -1,17 +1,19 @@
+---
+title: Reward Wrappers
+---
+
# Reward Wrappers

-## Base Class
-
```{eval-rst}
.. autoclass:: gymnasium.RewardWrapper

    .. automethod:: gymnasium.RewardWrapper.reward
```

-## Available Reward Wrappers
+## Implemented Wrappers

```{eval-rst}
.. autoclass:: gymnasium.wrappers.TransformReward
.. autoclass:: gymnasium.wrappers.NormalizeReward
+.. autoclass:: gymnasium.wrappers.ClipReward
```
diff --git a/docs/api/wrappers/table.md b/docs/api/wrappers/table.md
new file mode 100644
index 000000000..35de64fc9
--- /dev/null
+++ b/docs/api/wrappers/table.md
@@ -0,0 +1,102 @@
+
+# List of Wrappers
+
+Gymnasium provides a number of commonly used wrappers listed below. More information can be found on each
+wrapper in the page for its wrapper type.
+
+```{eval-rst}
+.. py:currentmodule:: gymnasium.wrappers
+
+.. list-table::
+    :header-rows: 1
+
+    * - Name
+      - Description
+    * - :class:`AtariPreprocessing`
+      - Implements the common preprocessing techniques for Atari environments (excluding frame stacking).
+    * - :class:`Autoreset`
+      - The wrapped environment is automatically reset when a terminated or truncated state is reached.
+    * - :class:`ClipAction`
+      - Clips the ``action`` passed to ``step`` to be within the environment's ``action_space``.
+    * - :class:`ClipReward`
+      - Clips the rewards for an environment between an upper and lower bound.
+    * - :class:`DelayObservation`
+      - Adds a delay to the returned observation from the environment.
+    * - :class:`DtypeObservation`
+      - Modifies the dtype of an observation array to a specified dtype.
+    * - :class:`FilterObservation`
+      - Filters Dict or Tuple observation spaces by a set of keys or indexes.
+    * - :class:`FlattenObservation`
+      - Flattens the environment's observation space and each observation from ``reset`` and ``step`` functions.
+    * - :class:`FrameStackObservation`
+      - Stacks the observations from the last ``N`` time steps in a rolling manner.
+    * - :class:`GrayscaleObservation`
+      - Converts an image observation computed by ``reset`` and ``step`` from RGB to grayscale.
+    * - :class:`HumanRendering`
+      - Allows human-like rendering for environments that support "rgb_array" rendering.
+    * - :class:`JaxToNumpy`
+      - Wraps a Jax-based environment such that it can be interacted with using NumPy arrays.
+    * - :class:`JaxToTorch`
+      - Wraps a Jax-based environment so that it can be interacted with using PyTorch Tensors.
+    * - :class:`MaxAndSkipObservation`
+      - Skips the N-th frame (observation) and returns the max values between the last two observations.
+    * - :class:`NormalizeObservation`
+      - Normalizes observations to be centered at the mean with unit variance.
+    * - :class:`NormalizeReward`
+      - Normalizes immediate rewards such that their exponential moving average has a fixed variance.
+    * - :class:`NumpyToTorch`
+      - Wraps a NumPy-based environment such that it can be interacted with using PyTorch Tensors.
+    * - :class:`OrderEnforcing`
+      - Will produce an error if ``step`` or ``render`` is called before ``reset``.
+    * - :class:`PassiveEnvChecker`
+      - A passive environment checker wrapper that surrounds the ``step``, ``reset`` and ``render`` functions to check that they follow gymnasium's API.
+    * - :class:`RecordEpisodeStatistics`
+      - This wrapper will keep track of cumulative rewards and episode lengths.
+    * - :class:`RecordVideo`
+      - Records videos of environment episodes using the environment's render function.
+    * - :class:`RenderCollection`
+      - Collects rendered frames of an environment such that ``render`` returns a ``list[RenderedFrame]``.
+    * - :class:`RenderObservation`
+      - Includes the rendered observations in the environment's observations.
+    * - :class:`RescaleAction`
+      - Affinely (linearly) rescales a ``Box`` action space of the environment to within the range of ``[min_action, max_action]``.
+    * - :class:`RescaleObservation`
+      - Affinely (linearly) rescales a ``Box`` observation space of the environment to within the range of ``[min_obs, max_obs]``.
+    * - :class:`ReshapeObservation`
+      - Reshapes array-based observations to a specified shape.
+    * - :class:`ResizeObservation`
+      - Resizes image observations using OpenCV to a specified shape.
+    * - :class:`StickyAction`
+      - Adds a probability that the previous action is repeated for the current ``step``.
+    * - :class:`TimeAwareObservation`
+      - Augments the observation with the number of time steps taken within an episode.
+    * - :class:`TimeLimit`
+      - Limits the number of steps for an environment by truncating it once a maximum number of timesteps is exceeded.
+    * - :class:`TransformAction`
+      - Applies a function to the ``action`` before passing the modified value to the environment ``step`` function.
+    * - :class:`TransformObservation`
+      - Applies a function to the ``observation`` received from the environment's ``reset`` and ``step`` that is passed back to the user.
+    * - :class:`TransformReward`
+      - Applies a function to the ``reward`` received from the environment's ``step``.
+
+```

+## Vector only Wrappers
+
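+A minimal sketch of applying a vector-only wrapper (assuming ``make_vec`` with the ``vectorization_mode`` keyword):
+
+```python
+import gymnasium as gym
+from gymnasium.wrappers.vector import DictInfoToList
+
+envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
+envs = DictInfoToList(envs)  # infos become a list with one dict per sub-environment
+observations, infos = envs.reset(seed=0)
+```
+
+```{eval-rst}
+.. py:currentmodule:: gymnasium.wrappers.vector
+
+.. 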
list-table:: + :header-rows: 1 + + * - Name + - Description + * - :class:`DictInfoToList` + - Converts infos of vectorized environments from ``dict`` to ``List[dict]``. + * - :class:`VectorizeTransformAction` + - Vectorizes a single-agent transform action wrapper for vector environments. + * - :class:`VectorizeTransformObservation` + - Vectorizes a single-agent transform observation wrapper for vector environments. + * - :class:`VectorizeTransformReward` + - Vectorizes a single-agent transform reward wrapper for vector environments. +``` diff --git a/docs/api/wrappers/vector_wrappers.md b/docs/api/wrappers/vector_wrappers.md new file mode 100644 index 000000000..e2636f642 --- /dev/null +++ b/docs/api/wrappers/vector_wrappers.md @@ -0,0 +1,19 @@ +--- +title: Vector Wrappers +--- + +# Vector wrappers + +## Vector only wrappers + +```{eval-rst} +.. autoclass:: gymnasium.wrappers.vector.DictInfoToList +``` + +## Vectorize Transform Wrappers to Vector Wrappers + +```{eval-rst} +.. autoclass:: gymnasium.wrappers.vector.VectorizeTransformObservation +.. autoclass:: gymnasium.wrappers.vector.VectorizeTransformAction +.. autoclass:: gymnasium.wrappers.vector.VectorizeTransformReward +``` diff --git a/docs/conf.py b/docs/conf.py index 1b5c465bb..5e0e31dbb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -40,10 +40,10 @@ release = gymnasium.__version__ # ones. extensions = [ "sphinx.ext.napoleon", - "sphinx.ext.doctest", "sphinx.ext.autodoc", "sphinx.ext.githubpages", "sphinx.ext.viewcode", + "sphinx.ext.coverage", "myst_parser", "furo.gen_tutorials", "sphinx_gallery.gen_gallery", diff --git a/docs/content/basic_usage.md b/docs/content/basic_usage.md deleted file mode 100644 index 99216d23e..000000000 --- a/docs/content/basic_usage.md +++ /dev/null @@ -1,134 +0,0 @@ ---- -layout: "contents" -title: Basic Usage -firstpage: ---- - -# Basic Usage - -Gymnasium is a project that provides an API for all single agent reinforcement learning environments, and includes implementations of common environments: cartpole, pendulum, mountain-car, mujoco, atari, and more. - -The API contains four key functions: ``make``, ``reset``, ``step`` and ``render``, that this basic usage will introduce you to. At the core of Gymnasium is ``Env``, a high-level python class representing a markov decision process (MDP) from reinforcement learning theory (this is not a perfect reconstruction, and is missing several components of MDPs). Within gymnasium, environments (MDPs) are implemented as ``Env`` classes, along with ``Wrappers``, which provide helpful utilities and can change the results passed to the user. - -## Initializing Environments - -Initializing environments is very easy in Gymnasium and can be done via the ``make`` function: - -```python -import gymnasium as gym -env = gym.make('CartPole-v1') -``` - -This will return an ``Env`` for users to interact with. To see all environments you can create, use ``gymnasium.envs.registry.keys()``.``make`` includes a number of additional parameters to adding wrappers, specifying keywords to the environment and more. - -## Interacting with the Environment - -The classic "agent-environment loop" pictured below is simplified representation of reinforcement learning that Gymnasium implements. 
- -```{image} /_static/diagrams/AE_loop.png -:width: 50% -:align: center -:class: only-light -``` - -```{image} /_static/diagrams/AE_loop_dark.png -:width: 50% -:align: center -:class: only-dark -``` - -This loop is implemented using the following gymnasium code - -```python -import gymnasium as gym -env = gym.make("LunarLander-v2", render_mode="human") -observation, info = env.reset() - -for _ in range(1000): - action = env.action_space.sample() # agent policy that uses the observation and info - observation, reward, terminated, truncated, info = env.step(action) - - if terminated or truncated: - observation, info = env.reset() - -env.close() -``` - -The output should look something like this: - -```{figure} https://user-images.githubusercontent.com/15806078/153222406-af5ce6f0-4696-4a24-a683-46ad4939170c.gif -:width: 50% -:align: center -``` - -### Explaining the code - -First, an environment is created using ``make`` with an additional keyword `"render_mode"` that specifies how the environment should be visualised. See ``render`` for details on the default meaning of different render modes. In this example, we use the ``"LunarLander"`` environment where the agent controls a spaceship that needs to land safely. - -After initializing the environment, we ``reset`` the environment to get the first observation of the environment. For initializing the environment with a particular random seed or options (see environment documentation for possible values) use the ``seed`` or ``options`` parameters with ``reset``. - -Next, the agent performs an action in the environment, ``step``, this can be imagined as moving a robot or pressing a button on a games' controller that causes a change within the environment. As a result, the agent receives a new observation from the updated environment along with a reward for taking the action. This reward could be for instance positive for destroying an enemy or a negative reward for moving into lava. One such action-observation exchange is referred to as a *timestep*. - -However, after some timesteps, the environment may end, this is called the terminal state. For instance, the robot may have crashed, or the agent have succeeded in completing a task, the environment will need to stop as the agent cannot continue. In gymnasium, if the environment has terminated, this is returned by ``step``. Similarly, we may also want the environment to end after a fixed number of timesteps, in this case, the environment issues a truncated signal. If either of ``terminated`` or ``truncated`` are `true` then ``reset`` should be called next to restart the environment. - -## Action and observation spaces - -Every environment specifies the format of valid actions and observations with the ``env.action_space`` and ``env.observation_space`` attributes. This is helpful for both knowing the expected input and output of the environment as all valid actions and observation should be contained with the respective space. - -In the example, we sampled random actions via ``env.action_space.sample()`` instead of using an agent policy, mapping observations to actions which users will want to make. See one of the agent tutorials for an example of creating and training an agent policy. - -Every environment should have the attributes ``action_space`` and ``observation_space``, both of which should be instances of classes that inherit from ``Space``. Gymnasium has support for a majority of possible spaces users might need: - -- ``Box``: describes an n-dimensional continuous space. 
It's a bounded space where we can define the upper and lower - limits which describe the valid values our observations can take. -- ``Discrete``: describes a discrete space where {0, 1, ..., n-1} are the possible values our observation or action can take. - Values can be shifted to {a, a+1, ..., a+n-1} using an optional argument. -- ``Dict``: represents a dictionary of simple spaces. -- ``Tuple``: represents a tuple of simple spaces. -- ``MultiBinary``: creates an n-shape binary space. Argument n can be a number or a list of numbers. -- ``MultiDiscrete``: consists of a series of ``Discrete`` action spaces with a different number of actions in each element. - -For example usage of spaces, see their [documentation](/api/spaces) along with [utility functions](/api/spaces/utils). There are a couple of more niche spaces ``Graph``, ``Sequence`` and ``Text``. - -## Modifying the environment - -Wrappers are a convenient way to modify an existing environment without having to alter the underlying code directly. Using wrappers will allow you to avoid a lot of boilerplate code and make your environment more modular. Wrappers can also be chained to combine their effects. Most environments that are generated via ``gymnasium.make`` will already be wrapped by default using the ``TimeLimit``, ``OrderEnforcing`` and ``PassiveEnvChecker``. - -In order to wrap an environment, you must first initialize a base environment. Then you can pass this environment along with (possibly optional) parameters to the wrapper's constructor: - -```python ->>> import gymnasium as gym ->>> from gymnasium.wrappers import FlattenObservation ->>> env = gym.make("CarRacing-v2") ->>> env.observation_space.shape -(96, 96, 3) ->>> wrapped_env = FlattenObservation(env) ->>> wrapped_env.observation_space.shape -(27648,) - -``` - -Gymnasium already provides many commonly used wrappers for you. Some examples: - -- `TimeLimit`: Issue a truncated signal if a maximum number of timesteps has been exceeded (or the base environment has issued a truncated signal). -- `ClipAction`: Clip the action such that it lies in the action space (of type `Box`). -- `RescaleAction`: Rescale actions to lie in a specified interval -- `TimeAwareObservation`: Add information about the index of timestep to observation. In some cases helpful to ensure that transitions are Markov. - -For a full list of implemented wrappers in gymnasium, see [wrappers](/api/wrappers). - -If you have a wrapped environment, and you want to get the unwrapped environment underneath all the layers of wrappers (so that you can manually call a function or change some underlying aspect of the environment), you can use the `.unwrapped` attribute. If the environment is already a base environment, the `.unwrapped` attribute will just return itself. - -```python ->>> wrapped_env ->>>>> ->>> wrapped_env.unwrapped - - -``` - -## More information - -* [Making a Custom environment using the Gymnasium API](/tutorials/gymnasium_basics/environment_creation/) -* [Training an agent to play blackjack](/tutorials/training_agents/blackjack_tutorial) -* [Compatibility with OpenAI Gym](/content/gym_compatibility) diff --git a/docs/content/migration-guide.md b/docs/content/migration-guide.md deleted file mode 100644 index b9b8bf530..000000000 --- a/docs/content/migration-guide.md +++ /dev/null @@ -1,122 +0,0 @@ ---- -layout: "contents" -title: Migration Guide ---- - -# v21 to v26 Migration Guide - -```{eval-rst} -.. 
py:currentmodule:: gymnasium.wrappers - -Gymnasium is a fork of `OpenAI Gym v26 `_, which introduced a large breaking change from `Gym v21 `_. -In this guide, we briefly outline the API changes from Gym v21 - which a number of tutorials have been written for - to Gym v26. -For environments still stuck in the v21 API, users can use the :class:`EnvCompatibility` wrapper to convert them to v26 compliant. -For more information, see the `guide `_ -``` - -### Example code for v21 - -```python -import gym -env = gym.make("LunarLander-v2", options={}) -env.seed(123) -observation = env.reset() - -done = False -while not done: - action = env.action_space.sample() # agent policy that uses the observation and info - observation, reward, done, info = env.step(action) - - env.render(mode="human") - -env.close() -``` - -### Example code for v26 - -```python -import gym -env = gym.make("LunarLander-v2", render_mode="human") -observation, info = env.reset(seed=123, options={}) - -done = False -while not done: - action = env.action_space.sample() # agent policy that uses the observation and info - observation, reward, terminated, truncated, info = env.step(action) - - done = terminated or truncated - -env.close() -``` - -## Seed and random number generator - -```{eval-rst} -.. py:currentmodule:: gymnasium.Env - -The ``Env.seed()`` has been removed from the Gym v26 environments in favour of ``Env.reset(seed=seed)``. -This allows seeding to only be changed on environment reset. -The decision to remove ``seed`` was because some environments use emulators that cannot change random number generators within an episode and must be done at the beginning of a new episode. -We are aware of cases where controlling the random number generator is important, in these cases, if the environment uses the built-in random number generator, users can set the seed manually with the attribute :attr:`np_random`. - -Gymnasium v26 changed to using ``numpy.random.Generator`` instead of a custom random number generator. -This means that several functions such as ``randint`` were removed in favour of ``integers``. -While some environments might use external random number generator, we recommend using the attribute :attr:`np_random` that wrappers and external users can access and utilise. -``` - -## Environment Reset - -```{eval-rst} -In v26, :meth:`reset` takes two optional parameters and returns one value. -This contrasts to v21 which takes no parameters and returns ``None``. -The two parameters are ``seed`` for setting the random number generator and ``options`` which allows additional data to be passed to the environment on reset. -For example, in classic control, the ``options`` parameter now allows users to modify the range of the state bound. -See the original `PR `_ for more details. - -:meth:`reset` further returns ``info``, similar to the ``info`` returned by :meth:`step`. -This is important because ``info`` can include metrics or valid action mask that is used or saved in the next step. - -To update older environments, we highly recommend that ``super().reset(seed=seed)`` is called on the first line of :meth:`reset`. -This will automatically update the :attr:`np_random` with the seed value. -``` - -## Environment Step - -```{eval-rst} -In v21, the type definition of :meth:`step` is ``tuple[ObsType, SupportsFloat, bool, dict[str, Any]`` representing the next observation, the reward from the step, if the episode is done and additional info from the step. 
-Due to reproducibility issues that will be expanded on in a blog post soon, we have changed the type definition to ``tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]`` adding an extra boolean value. -This extra bool corresponds to the older `done` now changed to `terminated` and `truncated`. -These changes were introduced in Gym `v26 `_ (turned off by default in `v25 `_). - -For users wishing to update, in most cases, replacing ``done`` with ``terminated`` and ``truncated=False`` in :meth:`step` should address most issues. -However, environments that have reasons for episode truncation rather than termination should read through the associated `PR `_. -For users looping through an environment, they should modify ``done = terminated or truncated`` as is show in the example code. -For training libraries, the primary difference is to change ``done`` to ``terminated``, indicating whether bootstrapping should or shouldn't happen. -``` - -## TimeLimit Wrapper -```{eval-rst} -In v21, the :class:`TimeLimit` wrapper added an extra key in the ``info`` dictionary ``TimeLimit.truncated`` whenever the agent reached the time limit without reaching a terminal state. - -In v26, this information is instead communicated through the `truncated` return value described in the previous section, which is `True` if the agent reaches the time limit, whether or not it reaches a terminal state. The old dictionary entry is equivalent to ``truncated and not terminated`` -``` - -## Environment Render - -```{eval-rst} -In v26, a new render API was introduced such that the render mode is fixed at initialisation as some environments don't allow on-the-fly render mode changes. Therefore, users should now specify the :attr:`render_mode` within ``gym.make`` as shown in the v26 example code above. - -For a more complete explanation of the changes, please refer to this `summary `_. -``` - -## Removed code - -```{eval-rst} -.. py:currentmodule:: gymnasium.wrappers - -* GoalEnv - This was removed, users needing it should reimplement the environment or use Gymnasium Robotics which contains an implementation of this environment. -* ``from gym.envs.classic_control import rendering`` - This was removed in favour of users implementing their own rendering systems. Gymnasium environments are coded using pygame. -* Robotics environments - The robotics environments have been moved to the `Gymnasium Robotics `_ project. 
-* Monitor wrapper - This wrapper was replaced with two separate wrapper, :class:`RecordVideo` and :class:`RecordEpisodeStatistics`
-
-```
diff --git a/docs/index.md b/docs/index.md
index a0a2d566f..1a66c2224 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -17,19 +17,27 @@ An API standard for reinforcement learning with a diverse collection of referenc
:width: 500
```

-**Gymnasium is a maintained fork of OpenAI’s Gym library.** The Gymnasium interface is simple, pythonic, and capable of representing general RL problems, and has a [compatibility wrapper](content/gym_compatibility) for old Gym environments:
+**Gymnasium is a maintained fork of OpenAI’s Gym library.** The Gymnasium interface is simple, pythonic, and capable of representing general RL problems, and has a [compatibility wrapper](introduction/gym_compatibility) for old Gym environments:

```{code-block} python
-
import gymnasium as gym
+
+# Initialise the environment
env = gym.make("LunarLander-v2", render_mode="human")
+
+# Reset the environment to generate the first observation
observation, info = env.reset(seed=42)
for _ in range(1000):
-    action = env.action_space.sample()  # this is where you would insert your policy
-    observation, reward, terminated, truncated, info = env.step(action)
+    # this is where you would insert your policy
+    action = env.action_space.sample()

-    if terminated or truncated:
-        observation, info = env.reset()
+    # step (transition) through the environment with the action
+    # receiving the next observation, reward and whether the episode has terminated or truncated
+    observation, reward, terminated, truncated, info = env.step(action)
+
+    # If the episode has ended then we can reset to start a new episode
+    if terminated or truncated:
+        observation, info = env.reset()

env.close()
```

@@ -38,9 +46,9 @@ env.close()
:hidden:
:caption: Introduction

-content/basic_usage
-content/gym_compatibility
-content/migration-guide
+introduction/basic_usage
+introduction/gym_compatibility
+introduction/migration-guide
```

```{toctree}
@@ -53,7 +61,7 @@
api/spaces
api/wrappers
api/vector
api/utils
-api/experimental
+api/functional
```
diff --git a/docs/introduction/basic_usage.md b/docs/introduction/basic_usage.md
new file mode 100644
index 000000000..77319eb9c
--- /dev/null
+++ b/docs/introduction/basic_usage.md
@@ -0,0 +1,172 @@
+---
+layout: "contents"
+title: Basic Usage
+firstpage:
+---
+
+# Basic Usage
+
+```{eval-rst}
+.. py:currentmodule:: gymnasium
+
+Gymnasium is a project that provides an API for all single agent reinforcement learning environments, and includes implementations of common environments: cartpole, pendulum, mountain-car, mujoco, atari, and more.
+
+The API contains four key functions: :meth:`make`, :meth:`Env.reset`, :meth:`Env.step` and :meth:`Env.render`, which this basic usage will introduce you to. At the core of Gymnasium is :class:`Env`, a high-level Python class representing a Markov decision process (MDP) from reinforcement learning theory (this is not a perfect reconstruction, and is missing several components of MDPs). Within gymnasium, environments (MDPs) are implemented as :class:`Env` classes, along with :class:`Wrapper` classes that provide helpful utilities to change actions passed to the environment and to modify the observations, rewards, termination or truncation conditions passed back to the user.
+```
+
+## Initializing Environments
+
+```{eval-rst}
+.. py:currentmodule:: gymnasium
+
+Initializing environments is very easy in Gymnasium and can be done via the :meth:`make` function:
+```
+
+```python
+import gymnasium as gym
+env = gym.make('CartPole-v1')
+```
+
+```{eval-rst}
+.. py:currentmodule:: gymnasium
+
+This will return an :class:`Env` for users to interact with. To see all environments you can create, use :meth:`pprint_registry`. Furthermore, :meth:`make` provides a number of additional arguments for specifying keywords to the environment, adding more or less wrappers, etc.
+```
+
+## Interacting with the Environment
+
+The classic "agent-environment loop" pictured below is a simplified representation of reinforcement learning that Gymnasium implements.
+
+```{image} /_static/diagrams/AE_loop.png
+:width: 50%
+:align: center
+:class: only-light
+```
+
+```{image} /_static/diagrams/AE_loop_dark.png
+:width: 50%
+:align: center
+:class: only-dark
+```
+
+This loop is implemented using the following gymnasium code:
+
+```python
+import gymnasium as gym
+env = gym.make("LunarLander-v2", render_mode="human")
+observation, info = env.reset()
+
+for _ in range(1000):
+    action = env.action_space.sample()  # agent policy that uses the observation and info
+    observation, reward, terminated, truncated, info = env.step(action)
+
+    if terminated or truncated:
+        observation, info = env.reset()
+
+env.close()
+```
+
+The output should look something like this:
+
+```{figure} https://user-images.githubusercontent.com/15806078/153222406-af5ce6f0-4696-4a24-a683-46ad4939170c.gif
+:width: 50%
+:align: center
+```
+
+### Explaining the code
+
+```{eval-rst}
+.. py:currentmodule:: gymnasium
+
+First, an environment is created using :meth:`make` with an additional keyword ``"render_mode"`` that specifies how the environment should be visualised.
+
+.. py:currentmodule:: gymnasium.Env
+
+See :meth:`render` for details on the default meaning of different render modes. In this example, we use the ``"LunarLander"`` environment where the agent controls a spaceship that needs to land safely.
+
+After initializing the environment, we :meth:`reset` the environment to get the first observation of the environment. For initializing the environment with a particular random seed or options (see environment documentation for possible values) use the ``seed`` or ``options`` parameters with :meth:`reset`.
+
+Next, the agent performs an action in the environment with :meth:`step`; this can be imagined as moving a robot or pressing a button on a game controller that causes a change within the environment. As a result, the agent receives a new observation from the updated environment along with a reward for taking the action. This reward could be for instance positive for destroying an enemy or a negative reward for moving into lava. One such action-observation exchange is referred to as a **timestep**.
+
+However, after some timesteps, the environment may end; this is called the terminal state. For instance, the robot may have crashed or the agent may have succeeded in completing a task; the environment then needs to stop as the agent cannot continue. In gymnasium, if the environment has terminated, this is returned by :meth:`step`. Similarly, we may also want the environment to end after a fixed number of timesteps, in this case, the environment issues a truncated signal. If either of ``terminated`` or ``truncated`` are ``True`` then :meth:`reset` should be called next to restart the environment.
+```
+
+## Action and observation spaces
+
+```{eval-rst}
+.. py:currentmodule:: gymnasium.Env
+
+Every environment specifies the format of valid actions and observations with the :attr:`action_space` and :attr:`observation_space` attributes. This is helpful for both knowing the expected input and output of the environment, as all valid actions and observations should be contained within their respective spaces.
+
+In the example, we sampled random actions via ``env.action_space.sample()`` instead of using an agent policy that maps observations to actions, which is what users will ultimately want to build. See one of the agent tutorials for an example of creating and training an agent policy.
+
+.. py:currentmodule:: gymnasium
+
+Every environment should have the attributes :attr:`Env.action_space` and :attr:`Env.observation_space`, both of which should be instances of classes that inherit from :class:`spaces.Space`. Gymnasium has support for a majority of possible spaces users might need:
+
+.. py:currentmodule:: gymnasium.spaces
+
+- :class:`Box`: describes an n-dimensional continuous space. It's a bounded space where we can define the upper and lower
+  limits which describe the valid values our observations can take.
+- :class:`Discrete`: describes a discrete space where ``{0, 1, ..., n-1}`` are the possible values our observation or action can take.
+  Values can be shifted to ``{a, a+1, ..., a+n-1}`` using an optional argument.
+- :class:`Dict`: represents a dictionary of simple spaces.
+- :class:`Tuple`: represents a tuple of simple spaces.
+- :class:`MultiBinary`: creates an n-shape binary space. Argument n can be a number or a list of numbers.
+- :class:`MultiDiscrete`: consists of a series of :class:`Discrete` action spaces with a different number of actions in each element.
+
+For example usage of spaces, see their `documentation `_ along with `utility functions `_. There are a couple of more niche spaces :class:`Graph`, :class:`Sequence` and :class:`Text`.
+```
+
+## Modifying the environment
+
+```{eval-rst}
+.. py:currentmodule:: gymnasium.wrappers
+
+Wrappers are a convenient way to modify an existing environment without having to alter the underlying code directly. Using wrappers will allow you to avoid a lot of boilerplate code and make your environment more modular. Wrappers can also be chained to combine their effects. Most environments that are generated via ``gymnasium.make`` will already be wrapped by default using the :class:`TimeLimit`, :class:`OrderEnforcing` and :class:`PassiveEnvChecker` wrappers.
+
+In order to wrap an environment, you must first initialize a base environment. Then you can pass this environment along with (possibly optional) parameters to the wrapper's constructor:
+```
+
+```python
+>>> import gymnasium as gym
+>>> from gymnasium.wrappers import FlattenObservation
+>>> env = gym.make("CarRacing-v2")
+>>> env.observation_space.shape
+(96, 96, 3)
+>>> wrapped_env = FlattenObservation(env)
+>>> wrapped_env.observation_space.shape
+(27648,)
+```
+
+```{eval-rst}
+.. py:currentmodule:: gymnasium.wrappers
+
+Gymnasium already provides many commonly used wrappers for you. Some examples:
+
+- :class:`TimeLimit`: Issue a truncated signal if a maximum number of timesteps has been exceeded (or the base environment has issued a truncated signal).
+- :class:`ClipAction`: Clip the action such that it lies in the action space (of type ``Box``).
+- :class:`RescaleAction`: Rescale actions to lie in a specified interval.
+- :class:`TimeAwareObservation`: Add information about the index of the timestep to the observation. In some cases helpful to ensure that transitions are Markov.
+```
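+
+As a short sketch of one such wrapper (the exact space reprs may differ slightly between versions):
+
+```python
+>>> import gymnasium as gym
+>>> from gymnasium.wrappers import RescaleAction
+>>> env = gym.make("Pendulum-v1")
+>>> env.action_space
+Box(-2.0, 2.0, (1,), float32)
+>>> RescaleAction(env, min_action=-1.0, max_action=1.0).action_space
+Box(-1.0, 1.0, (1,), float32)
+```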
+
+For a full list of implemented wrappers in gymnasium, see [wrappers](/api/wrappers).
+
+```{eval-rst}
+.. py:currentmodule:: gymnasium.Env
+
+If you have a wrapped environment, and you want to get the unwrapped environment underneath all the layers of wrappers (so that you can manually call a function or change some underlying aspect of the environment), you can use the :attr:`unwrapped` attribute. If the environment is already a base environment, the :attr:`unwrapped` attribute will just return itself.
+```
+
+```python
+>>> wrapped_env
+<FlattenObservation<TimeLimit<OrderEnforcing<PassiveEnvChecker<CarRacing<CarRacing-v2>>>>>>
+>>> wrapped_env.unwrapped
+<gymnasium.envs.box2d.car_racing.CarRacing object at 0x...>
+```
+
+## More information
+
+* [Making a Custom environment using the Gymnasium API](/tutorials/gymnasium_basics/environment_creation/)
+* [Training an agent to play blackjack](/tutorials/training_agents/blackjack_tutorial)
+* [Compatibility with OpenAI Gym](/introduction/gym_compatibility)
diff --git a/docs/content/gym_compatibility.md b/docs/introduction/gym_compatibility.md
similarity index 65%
rename from docs/content/gym_compatibility.md
rename to docs/introduction/gym_compatibility.md
index 8be75bdd6..abd75eb85 100644
--- a/docs/content/gym_compatibility.md
+++ b/docs/introduction/gym_compatibility.md
@@ -12,9 +12,7 @@ Gymnasium provides a number of compatibility methods for a range of Environment

```{eval-rst}
.. py:currentmodule:: gymnasium.wrappers

-For environments that are registered solely in OpenAI Gym and not in Gymnasium, Gymnasium v0.26.3 and above allows importing them through either a special environment or a wrapper.
-The ``"GymV26Environment-v0"`` environment was introduced in Gymnasium v0.26.3, and allows importing of Gym environments through the ``env_name`` argument along with other relevant kwargs environment kwargs.
-To perform conversion through a wrapper, the environment itself can be passed to the wrapper :class:`EnvCompatibility` through the ``env`` kwarg.
+For environments that are registered solely in OpenAI Gym and not in Gymnasium, Gymnasium v0.26.3 and above allows importing them through either a special environment or a wrapper. The ``"GymV26Environment-v0"`` environment was introduced in Gymnasium v0.26.3, and allows importing of Gym environments through the ``env_id`` argument along with other relevant environment kwargs. To perform conversion through a wrapper, the environment itself can be passed to the wrapper :class:`EnvCompatibility` through the ``env`` kwarg.
```

An example of this is atari 0.8.0 which does not have a gymnasium implementation.

```python
import gymnasium as gym
env = gym.make("GymV26Environment-v0", env_id="ALE/Pong-v5")
```

## Gym v0.21 Environment Compatibility

```{eval-rst}
.. py:currentmodule:: gymnasium

-A number of environments have not updated to the recent Gym changes, in particular since v0.21.
-This update is significant for the introduction of ``termination`` and ``truncation`` signatures in favour of the previously used ``done``.
-To allow backward compatibility, Gym and Gymnasium v0.26+ include an ``apply_api_compatibility`` kwarg when calling :meth:`make` that automatically converts a v0.21 API compliant environment to one that is compatible with v0.26+.
+A number of environments have not been updated to the recent Gym changes, in particular since v0.21. This update is significant for the introduction of ``termination`` and ``truncation`` signatures in favour of the previously used ``done``. To allow backward compatibility, Gym and Gymnasium v0.26+ include an ``apply_api_compatibility`` kwarg when calling :meth:`make` that automatically converts a v0.21 API compliant environment to one that is compatible with v0.26+.
+
+## Environment Reset
+
+```{eval-rst}
+In v0.26, :meth:`reset` takes two optional keyword parameters and returns two values. This contrasts with v0.21, which takes no parameters and returns only the observation. The two parameters are ``seed`` for setting the random number generator and ``options``, which allows additional data to be passed to the environment on reset. For example, in classic control, the ``options`` parameter now allows users to modify the range of the state bound. See the original `PR `_ for more details.
+
+In addition to the observation, :meth:`reset` returns ``info``, similar to the ``info`` returned by :meth:`step`. This is important because ``info`` can include metrics or a valid action mask that is used or saved for the next step.
+
+To update older environments, we highly recommend that ``super().reset(seed=seed)`` is called on the first line of :meth:`reset`. This will automatically update :attr:`np_random` with the seed value.
+```
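+
+As a minimal sketch (``MyEnv`` is a hypothetical environment, not part of Gym or Gymnasium), an updated ``reset`` could look like this:
+
+```python
+import gym
+import numpy as np
+
+
+class MyEnv(gym.Env):
+    observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))
+    action_space = gym.spaces.Discrete(2)
+
+    def reset(self, *, seed=None, options=None):
+        super().reset(seed=seed)  # seeds self.np_random with the given seed
+        observation = self.np_random.uniform(-1.0, 1.0, size=3).astype(np.float32)
+        return observation, {}  # v0.26: reset also returns an info dictionary
+```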
+
+## Environment Step
+
+```{eval-rst}
+In v0.21, the type definition of :meth:`step` is ``tuple[ObsType, SupportsFloat, bool, dict[str, Any]]``, representing the next observation, the reward from the step, whether the episode is done, and additional info from the step. Due to reproducibility issues that will be expanded on in a blog post soon, we have changed the type definition to ``tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]``, adding an extra boolean value. The older ``done`` value has been split into ``terminated`` and ``truncated``. These changes were introduced in Gym `v0.26 `_ (turned off by default in `v0.25 `_).
+
+For users wishing to update, replacing ``done`` with ``terminated`` and ``truncated=False`` in :meth:`step` should address most issues. However, environments that have reasons for episode truncation rather than termination should read through the associated `PR `_. Users looping through an environment should use ``done = terminated or truncated``, as shown in the example code above. For training libraries, the primary difference is to change ``done`` to ``terminated``, indicating whether bootstrapping should or shouldn't happen.
+```
+
+## TimeLimit Wrapper
+
+```{eval-rst}
+In v0.21, the :class:`TimeLimit` wrapper added an extra key, ``TimeLimit.truncated``, to the ``info`` dictionary whenever the agent reached the time limit without reaching a terminal state.
+
+In v0.26, this information is instead communicated through the ``truncated`` return value described in the previous section, which is ``True`` if the agent reaches the time limit, whether or not it reaches a terminal state. The old dictionary entry is equivalent to ``truncated and not terminated``.
+```
+
+## Environment Render
+
+```{eval-rst}
+In v0.26, a new render API was introduced such that the render mode is fixed at initialisation, as some environments don't allow on-the-fly render mode changes. Therefore, users should now specify the :attr:`render_mode` within ``gym.make`` as shown in the v0.26 example code above.
+
+For a more complete explanation of the changes, please refer to this `summary `_.
+```
+
+## Removed code
+
+```{eval-rst}
+.. py:currentmodule:: gymnasium.wrappers
+
+* GoalEnv - This was removed; users needing it should reimplement the environment or use Gymnasium Robotics, which contains an implementation of this environment.
+* ``from gym.envs.classic_control import rendering`` - This was removed in favour of users implementing their own rendering systems. Gymnasium environments are coded using pygame.
+* Robotics environments - The robotics environments have been moved to the `Gymnasium Robotics `_ project.
+* Monitor wrapper - This wrapper was replaced with two separate wrappers, :class:`RecordVideo` and :class:`RecordEpisodeStatistics`; a short sketch of the replacement follows below.
+```
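+
+As a sketch, the old ``Monitor`` behaviour can be approximated by composing the two wrappers (the ``video_folder`` value is illustrative):
+
+```python
+import gymnasium as gym
+from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo
+
+env = gym.make("CartPole-v1", render_mode="rgb_array")
+env = RecordVideo(env, video_folder="videos")  # replaces Monitor's video recording
+env = RecordEpisodeStatistics(env)  # records episode return and length in `info`
+```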
diff --git a/gymnasium/__init__.py b/gymnasium/__init__.py index 5e6c6f85f..106f7c4b9 100644 --- a/gymnasium/__init__.py +++ b/gymnasium/__init__.py @@ -21,8 +21,7 @@ from gymnasium.envs.registration import ( # necessary for `envs.__init__` which registers all gymnasium environments and loads plugins from gymnasium import envs -from gymnasium import spaces, utils, vector, wrappers, error, logger -from gymnasium import experimental +from gymnasium import spaces, utils, vector, wrappers, error, logger, functional __all__ = [ @@ -43,15 +42,15 @@ __all__ = [ "register_envs", # module folders "envs", - "experimental", "spaces", "utils", "vector", "wrappers", "error", "logger", + "functional", ] -__version__ = "0.29.0" +__version__ = "1.0.0rc1" # Initializing pygame initializes audio connections through SDL. SDL uses alsa by default on all Linux systems diff --git a/gymnasium/core.py b/gymnasium/core.py index 7641f8e65..bf8e87dc6 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Generic, SupportsFloat, TypeVar import numpy as np -from gymnasium import logger, spaces +from gymnasium import spaces from gymnasium.utils import RecordConstructorArgs, seeding @@ -37,12 +37,10 @@ class Env(Generic[ObsType, ActType]): - :attr:`action_space` - The Space object corresponding to valid actions, all valid actions should be contained within the space. - :attr:`observation_space` - The Space object corresponding to valid observations, all valid observations should be contained within the space. - - :attr:`reward_range` - A tuple corresponding to the minimum and maximum possible rewards for an agent over an episode. - The default reward range is set to :math:`(-\infty,+\infty)`. - :attr:`spec` - An environment spec that contains the information used to initialize the environment from :meth:`gymnasium.make` - :attr:`metadata` - The metadata of the environment, i.e. render modes, render fps - :attr:`np_random` - The random number generator for the environment. This is automatically assigned during - ``super().reset(seed=seed)`` and when assessing ``self.np_random``. + ``super().reset(seed=seed)`` and when accessing :attr:`np_random`. ..
seealso:: For modifying or extending environments use the :py:class:`gymnasium.Wrapper` class @@ -54,7 +52,6 @@ class Env(Generic[ObsType, ActType]): metadata: dict[str, Any] = {"render_modes": []} # define render_mode if your environment supports rendering render_mode: str | None = None - reward_range = (-float("inf"), float("inf")) spec: EnvSpec | None = None # Set these in ALL subclasses @@ -238,6 +235,10 @@ class Env(Generic[ObsType, ActType]): """Gets the attribute `name` from the environment.""" return getattr(self, name) + def set_wrapper_attr(self, name: str, value: Any): + """Sets the attribute `name` on the environment with `value`.""" + setattr(self, name, value) + WrapperObsType = TypeVar("WrapperObsType") WrapperActType = TypeVar("WrapperActType") @@ -268,56 +269,41 @@ class Wrapper( env: The environment to wrap """ self.env = env + assert isinstance(env, Env) self._action_space: spaces.Space[WrapperActType] | None = None self._observation_space: spaces.Space[WrapperObsType] | None = None - self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None self._metadata: dict[str, Any] | None = None self._cached_spec: EnvSpec | None = None - def __getattr__(self, name: str) -> Any: - """Returns an attribute with ``name``, unless ``name`` starts with an underscore. + def step( + self, action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: + """Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data.""" + return self.env.step(action) - Args: - name: The variable name + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: + """Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data.""" + return self.env.reset(seed=seed, options=options) - Returns: - The value of the variable in the wrapper stack + def render(self) -> RenderFrame | list[RenderFrame] | None: + """Uses the :meth:`render` of the :attr:`env` that can be overwritten to change the returned data.""" + return self.env.render() - Warnings: - This feature is deprecated and removed in v1.0 and replaced with `env.get_attr(name})` + def close(self): + """Closes the wrapper and :attr:`env`.""" + return self.env.close() + + @property + def unwrapped(self) -> Env[ObsType, ActType]: + """Returns the base environment of the wrapper. + + This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers. """ - if name == "_np_random": - raise AttributeError( - "Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`." - ) - elif name.startswith("_"): - raise AttributeError(f"accessing private attribute '{name}' is prohibited") - logger.warn( - f"env.{name} to get variables from other wrappers is deprecated and will be removed in v1.0, " - f"to get this variable you can do `env.unwrapped.{name}` for environment variables or `env.get_attr('{name}')` that will search the remaining wrappers." - ) - return getattr(self.env, name) - - def get_wrapper_attr(self, name: str) -> Any: - """Gets an attribute from the wrapper and lower environments if `name` doesn't exist in this object. 
- - Args: - name: The variable name to get - - Returns: - The variable with name in wrapper or lower environments - """ - if hasattr(self, name): - return getattr(self, name) - else: - try: - return self.env.get_wrapper_attr(name) - except AttributeError as e: - raise AttributeError( - f"wrapper {self.class_name()} has no attribute {name!r}" - ) from e + return self.env.unwrapped @property def spec(self) -> EnvSpec | None: @@ -362,6 +348,53 @@ class Wrapper( kwargs=kwargs, ) + def get_wrapper_attr(self, name: str) -> Any: + """Gets an attribute from the wrapper and lower environments if `name` doesn't exist in this object. + + Args: + name: The variable name to get + + Returns: + The variable with name in wrapper or lower environments + """ + if hasattr(self, name): + return getattr(self, name) + else: + try: + return self.env.get_wrapper_attr(name) + except AttributeError as e: + raise AttributeError( + f"wrapper {self.class_name()} has no attribute {name!r}" + ) from e + + def set_wrapper_attr(self, name: str, value: Any): + """Sets an attribute on this wrapper or lower environment if `name` is already defined. + + Args: + name: The variable name + value: The new variable value + """ + sub_env = self.env + attr_set = False + + while attr_set is False and isinstance(sub_env, Wrapper): + if hasattr(sub_env, name): + setattr(sub_env, name, value) + attr_set = True + else: + sub_env = sub_env.env + + if attr_set is False: + setattr(sub_env, name, value) + + def __str__(self): + """Returns the wrapper name and the :attr:`env` representation string.""" + return f"<{type(self).__name__}{self.env}>" + + def __repr__(self): + """Returns the string representation of the wrapper.""" + return str(self) + @classmethod def class_name(cls) -> str: """Returns the class name of the wrapper.""" @@ -393,18 +426,6 @@ class Wrapper( def observation_space(self, space: spaces.Space[WrapperObsType]): self._observation_space = space - @property - def reward_range(self) -> tuple[SupportsFloat, SupportsFloat]: - """Return the :attr:`Env` :attr:`reward_range` unless overwritten then the wrapper :attr:`reward_range` is used.""" - if self._reward_range is None: - return self.env.reward_range - logger.warn("The `reward_range` is deprecated and will be removed in v1.0") - return self._reward_range - - @reward_range.setter - def reward_range(self, value: tuple[SupportsFloat, SupportsFloat]): - self._reward_range = value - @property def metadata(self) -> dict[str, Any]: """Returns the :attr:`Env` :attr:`metadata`.""" @@ -440,54 +461,15 @@ class Wrapper( "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`." 
) - def step( - self, action: WrapperActType - ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: - """Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data.""" - return self.env.step(action) - - def reset( - self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[WrapperObsType, dict[str, Any]]: - """Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data.""" - return self.env.reset(seed=seed, options=options) - - def render(self) -> RenderFrame | list[RenderFrame] | None: - """Uses the :meth:`render` of the :attr:`env` that can be overwritten to change the returned data.""" - return self.env.render() - - def close(self): - """Closes the wrapper and :attr:`env`.""" - return self.env.close() - - def __str__(self): - """Returns the wrapper name and the :attr:`env` representation string.""" - return f"<{type(self).__name__}{self.env}>" - - def __repr__(self): - """Returns the string representation of the wrapper.""" - return str(self) - - @property - def unwrapped(self) -> Env[ObsType, ActType]: - """Returns the base environment of the wrapper. - - This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers. - """ - return self.env.unwrapped - class ObservationWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]): - """Superclass of wrappers that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`. + """Modify observations from :meth:`Env.reset` and :meth:`Env.step` using :meth:`observation` function. If you would like to apply a function to only the observation before passing it to the learning code, you can simply inherit from :class:`ObservationWrapper` and overwrite the method :meth:`observation` to implement that transformation. The transformation defined in that method must be reflected by the :attr:`env` observation space. Otherwise, you need to specify the new observation space of the wrapper by setting :attr:`self.observation_space` in the :meth:`__init__` method of your wrapper. - - Among others, Gymnasium provides the observation wrapper :class:`TimeAwareObservation`, which adds information about the - index of the timestep to the observation. 
""" def __init__(self, env: Env[ObsType, ActType]): diff --git a/gymnasium/envs/__init__.py b/gymnasium/envs/__init__.py index e28a0cb35..507d075b4 100644 --- a/gymnasium/envs/__init__.py +++ b/gymnasium/envs/__init__.py @@ -1,14 +1,7 @@ """Registers the internal gym envs then loads the env plugins for module using the entry point.""" from typing import Any -from gymnasium.envs.registration import ( - load_plugin_envs, - make, - pprint_registry, - register, - registry, - spec, -) +from gymnasium.envs.registration import make, pprint_registry, register, registry, spec # Classic @@ -459,7 +452,3 @@ def _raise_shimmy_error(*args: Any, **kwargs: Any): # When installed, shimmy will re-register these environments with the correct entry_point register(id="GymV21Environment-v0", entry_point=_raise_shimmy_error) register(id="GymV26Environment-v0", entry_point=_raise_shimmy_error) - - -# Hook to load plugins from entry points -load_plugin_envs() diff --git a/gymnasium/envs/box2d/lunar_lander.py b/gymnasium/envs/box2d/lunar_lander.py index 4e092a3bc..87f90b57a 100644 --- a/gymnasium/envs/box2d/lunar_lander.py +++ b/gymnasium/envs/box2d/lunar_lander.py @@ -840,7 +840,7 @@ def heuristic(env, s): -(s[3]) * 0.5 ) # override to reduce fall speed, that's all we need after contact - if env.continuous: + if env.unwrapped.continuous: a = np.array([hover_todo * 20 - 1, -angle_todo * 20]) a = np.clip(a, -1, +1) else: diff --git a/gymnasium/envs/classic_control/cartpole.py b/gymnasium/envs/classic_control/cartpole.py index 37311823c..daebccf63 100644 --- a/gymnasium/envs/classic_control/cartpole.py +++ b/gymnasium/envs/classic_control/cartpole.py @@ -12,7 +12,7 @@ import gymnasium as gym from gymnasium import logger, spaces from gymnasium.envs.classic_control import utils from gymnasium.error import DependencyNotInstalled -from gymnasium.experimental.vector import VectorEnv +from gymnasium.vector import VectorEnv from gymnasium.vector.utils import batch_space @@ -74,13 +74,29 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]): ## Arguments - ```python - import gymnasium as gym - gym.make('CartPole-v1') - ``` + Cartpole only has ``render_mode`` as a keyword for ``gymnasium.make``. + On reset, the `options` parameter allows the user to change the bounds used to determine the new random state. - On reset, the `options` parameter allows the user to change the bounds used to determine - the new random state. + Examples: + >>> import gymnasium as gym + >>> env = gym.make("CartPole-v1", render_mode="rgb_array") + >>> env + >>>> + >>> env.reset(seed=123, options={"low": 0, "high": 1}) + (array([0.6823519 , 0.05382102, 0.22035988, 0.18437181], dtype=float32), {}) + + ## Vectorized environment + + To increase steps per seconds, users can use a custom vector environment or with an environment vectorizor. 
+ + Examples: + >>> import gymnasium as gym + >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="vector_entry_point") + >>> envs + CartPoleVectorEnv(CartPole-v1, num_envs=3) + >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync") + >>> envs + SyncVectorEnv(CartPole-v1, num_envs=3) """ metadata = { @@ -328,8 +344,10 @@ class CartPoleVectorEnv(VectorEnv): max_episode_steps: int = 500, render_mode: Optional[str] = None, ): - super().__init__() self.num_envs = num_envs + self.max_episode_steps = max_episode_steps + self.render_mode = render_mode + self.gravity = 9.8 self.masscart = 1.0 self.masspole = 0.1 @@ -339,7 +357,6 @@ class CartPoleVectorEnv(VectorEnv): self.force_mag = 10.0 self.tau = 0.02 # seconds between state updates self.kinematics_integrator = "euler" - self.max_episode_steps = max_episode_steps self.steps = np.zeros(num_envs, dtype=np.int32) @@ -367,8 +384,6 @@ class CartPoleVectorEnv(VectorEnv): self.single_observation_space = spaces.Box(-high, high, dtype=np.float32) self.observation_space = batch_space(self.single_observation_space, num_envs) - self.render_mode = render_mode - self.screen_width = 600 self.screen_height = 400 self.screens = None @@ -464,6 +479,7 @@ class CartPoleVectorEnv(VectorEnv): def render(self): if self.render_mode is None: + assert self.spec is not None gym.logger.warn( "You are calling render method without specifying any render mode. " "You can specify the render_mode at initialization, " diff --git a/gymnasium/experimental/functional_jax_env.py b/gymnasium/envs/functional_jax_env.py similarity index 96% rename from gymnasium/experimental/functional_jax_env.py rename to gymnasium/envs/functional_jax_env.py index 631c69484..ee0ec7552 100644 --- a/gymnasium/experimental/functional_jax_env.py +++ b/gymnasium/envs/functional_jax_env.py @@ -10,10 +10,10 @@ import numpy as np import gymnasium as gym from gymnasium.envs.registration import EnvSpec -from gymnasium.experimental.functional import ActType, FuncEnv, StateType -from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy +from gymnasium.functional import ActType, FuncEnv, StateType from gymnasium.utils import seeding from gymnasium.vector.utils import batch_space +from gymnasium.wrappers.jax_to_numpy import jax_to_numpy class FunctionalJaxEnv(gym.Env): @@ -89,7 +89,7 @@ class FunctionalJaxEnv(gym.Env): observation = self.func_env.observation(next_state) reward = self.func_env.reward(self.state, action, next_state) terminated = self.func_env.terminal(next_state) - info = self.func_env.step_info(self.state, action, next_state) + info = self.func_env.transition_info(self.state, action, next_state) self.state = next_state observation = jax_to_numpy(observation) @@ -113,7 +113,7 @@ class FunctionalJaxEnv(gym.Env): self.render_state = None -class FunctionalJaxVectorEnv(gym.experimental.vector.VectorEnv): +class FunctionalJaxVectorEnv(gym.vector.VectorEnv): """A vector env implementation for functional Jax envs.""" state: StateType @@ -211,7 +211,7 @@ class FunctionalJaxVectorEnv(gym.experimental.vector.VectorEnv): else jnp.zeros_like(terminated) ) - info = self.func_env.step_info(self.state, action, next_state) + info = self.func_env.transition_info(self.state, action, next_state) done = jnp.logical_or(terminated, truncated) if jnp.any(done): diff --git a/gymnasium/envs/phys2d/cartpole.py b/gymnasium/envs/phys2d/cartpole.py index b8a38d382..1aca6ee3f 100644 --- a/gymnasium/envs/phys2d/cartpole.py +++ b/gymnasium/envs/phys2d/cartpole.py @@ -9,12 +9,9 
@@ import numpy as np from jax.random import PRNGKey import gymnasium as gym +from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv from gymnasium.error import DependencyNotInstalled -from gymnasium.experimental.functional import ActType, FuncEnv, StateType -from gymnasium.experimental.functional_jax_env import ( - FunctionalJaxEnv, - FunctionalJaxVectorEnv, -) +from gymnasium.functional import ActType, FuncEnv, StateType from gymnasium.utils import EzPickle diff --git a/gymnasium/envs/phys2d/pendulum.py b/gymnasium/envs/phys2d/pendulum.py index 2aa66c943..04d55dc10 100644 --- a/gymnasium/envs/phys2d/pendulum.py +++ b/gymnasium/envs/phys2d/pendulum.py @@ -10,12 +10,9 @@ import numpy as np from jax.random import PRNGKey import gymnasium as gym +from gymnasium.envs.functional_jax_env import FunctionalJaxEnv, FunctionalJaxVectorEnv from gymnasium.error import DependencyNotInstalled -from gymnasium.experimental.functional import ActType, FuncEnv, StateType -from gymnasium.experimental.functional_jax_env import ( - FunctionalJaxEnv, - FunctionalJaxVectorEnv, -) +from gymnasium.functional import ActType, FuncEnv, StateType from gymnasium.utils import EzPickle diff --git a/gymnasium/envs/registration.py b/gymnasium/envs/registration.py index 8d6d1e8c3..711d5f20e 100644 --- a/gymnasium/envs/registration.py +++ b/gymnasium/envs/registration.py @@ -10,7 +10,6 @@ import importlib.util import json import re import sys -import traceback from collections import defaultdict from dataclasses import dataclass, field from types import ModuleType @@ -86,11 +85,13 @@ class EnvSpec: * **nondeterministic**: If the observation of an environment cannot be repeated with the same initial state, random number generator state and actions. * **max_episode_steps**: The max number of steps that the environment can take before truncation * **order_enforce**: If to enforce the order of :meth:`gymnasium.Env.reset` before :meth:`gymnasium.Env.step` and :meth:`gymnasium.Env.render` functions - * **autoreset**: If to automatically reset the environment on episode end * **disable_env_checker**: If to disable the environment checker wrapper in :meth:`gymnasium.make`, by default False (runs the environment checker) * **kwargs**: Additional keyword arguments passed to the environment during initialisation * **additional_wrappers**: A tuple of additional wrappers applied to the environment (WrapperSpec) * **vector_entry_point**: The location of the vectorized environment to create from + + Changelogs: + v1.0.0 - Autoreset attribute removed """ id: str @@ -103,9 +104,7 @@ class EnvSpec: # Wrappers max_episode_steps: int | None = field(default=None) order_enforce: bool = field(default=True) - autoreset: bool = field(default=False) disable_env_checker: bool = field(default=False) - apply_api_compatibility: bool = field(default=False) # Environment arguments kwargs: dict = field(default_factory=dict) @@ -224,12 +223,8 @@ class EnvSpec: output += f"\nmax_episode_steps={self.max_episode_steps}" if print_all or self.order_enforce is not True: output += f"\norder_enforce={self.order_enforce}" - if print_all or self.autoreset is not False: - output += f"\nautoreset={self.autoreset}" if print_all or self.disable_env_checker is not False: output += f"\ndisable_env_checker={self.disable_env_checker}" - if print_all or self.apply_api_compatibility is not False: - output += f"\napplied_api_compatibility={self.apply_api_compatibility}" if print_all or self.additional_wrappers: wrapper_output: list[str] = [] @@ 
-547,55 +542,6 @@ def load_env_creator(name: str) -> EnvCreator | VectorEnvCreator: return fn -def load_plugin_envs(entry_point: str = "gymnasium.envs"): - """Load modules (plugins) using the gymnasium entry points in order to register external module's environments on ``import gymnasium``. - - Args: - entry_point: The string for the entry point. - """ - # Load third-party environments - for plugin in metadata.entry_points(group=entry_point): - # Python 3.8 doesn't support plugin.module, plugin.attr - # So we'll have to try and parse this ourselves - module, attr = None, None - try: - module, attr = plugin.module, plugin.attr # type: ignore ## error: Cannot access member "attr" for type "EntryPoint" - except AttributeError: - if ":" in plugin.value: - module, attr = plugin.value.split(":", maxsplit=1) - else: - module, attr = plugin.value, None - except Exception as e: - logger.warn( - f"While trying to load plugin `{plugin}` from {entry_point}, an exception occurred: {e}" - ) - module, attr = None, None - finally: - if attr is None: - raise error.Error( - f"Gymnasium environment plugin `{module}` must specify a function to execute, not a root module" - ) - - context = namespace(plugin.name) - if plugin.name.startswith("__") and plugin.name.endswith("__"): - # `__internal__` is an artifact of the plugin system when the root namespace had an allow-list. - # The allow-list is now removed and plugins can register environments in the root namespace with the `__root__` magic key. - if plugin.name == "__root__" or plugin.name == "__internal__": - context = contextlib.nullcontext() - else: - logger.warn( - f"The environment namespace magic key `{plugin.name}` is unsupported. " - "To register an environment at the root namespace you should specify the `__root__` namespace." - ) - - with context: - fn = plugin.load() - try: - fn() - except Exception: - logger.warn(f"plugin: {plugin.value} raised {traceback.format_exc()}") - - def register_envs(env_module: ModuleType): """A No-op function such that it can appear to IDEs that a module is used.""" pass @@ -618,9 +564,7 @@ def register( nondeterministic: bool = False, max_episode_steps: int | None = None, order_enforce: bool = True, - autoreset: bool = False, disable_env_checker: bool = False, - apply_api_compatibility: bool = False, additional_wrappers: tuple[WrapperSpec, ...] = (), vector_entry_point: VectorEnvCreator | str | None = None, **kwargs: Any, @@ -640,13 +584,13 @@ def register( max_episode_steps: The maximum number of episodes steps before truncation. Used by the :class:`gymnasium.wrappers.TimeLimit` wrapper if not ``None``. order_enforce: If to enable the order enforcer wrapper to ensure users run functions in the correct order. If ``True``, then the :class:`gymnasium.wrappers.OrderEnforcing` is applied to the environment. - autoreset: If to add the :class:`gymnasium.wrappers.AutoResetWrapper` such that on ``(terminated or truncated) is True``, :meth:`gymnasium.Env.reset` is called. disable_env_checker: If to disable the :class:`gymnasium.wrappers.PassiveEnvChecker` to the environment. - apply_api_compatibility: If to apply the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper to the environment. - Use if the environment is implemented in the gym v0.21 environment API. additional_wrappers: Additional wrappers to apply the environment. vector_entry_point: The entry point for creating the vector environment **kwargs: arbitrary keyword arguments which are passed to the environment constructor on initialisation. 
+ + Changelogs: + v1.0.0 - `autoreset` and `apply_api_compatibility` parameters were removed """ assert ( entry_point is not None or vector_entry_point is not None @@ -669,11 +613,6 @@ def register( ns_id = ns full_env_id = get_env_id(ns_id, name, version) - if autoreset is True: - logger.warn( - "`gymnasium.register(..., autoreset=True)` is deprecated and will be removed in v1.0. If users wish to use it then add the auto reset wrapper in the `addition_wrappers` argument." - ) - new_spec = EnvSpec( id=full_env_id, entry_point=entry_point, @@ -681,9 +620,7 @@ def register( nondeterministic=nondeterministic, max_episode_steps=max_episode_steps, order_enforce=order_enforce, - autoreset=autoreset, disable_env_checker=disable_env_checker, - apply_api_compatibility=apply_api_compatibility, **kwargs, additional_wrappers=additional_wrappers, vector_entry_point=vector_entry_point, @@ -698,8 +635,6 @@ def register( def make( id: str | EnvSpec, max_episode_steps: int | None = None, - autoreset: bool | None = None, - apply_api_compatibility: bool | None = None, disable_env_checker: bool | None = None, **kwargs: Any, ) -> Env: @@ -710,12 +645,9 @@ def make( Args: id: A string for the environment id or a :class:`EnvSpec`. Optionally if using a string, a module to import can be included, e.g. ``'module:Env-v0'``. This is equivalent to importing the module first to register the environment followed by making the environment. - max_episode_steps: Maximum length of an episode, can override the registered :class:`EnvSpec` ``max_episode_steps``. - The value is used by :class:`gymnasium.wrappers.TimeLimit`. - autoreset: Whether to automatically reset the environment after each episode (:class:`gymnasium.wrappers.AutoResetWrapper`). - apply_api_compatibility: Whether to wrap the environment with the :class:`gymnasium.wrappers.StepAPICompatibility` wrapper that - converts the environment step from a done bool to return termination and truncation bools. - By default, the argument is None in which the :class:`EnvSpec` ``apply_api_compatibility`` is used, otherwise this variable is used in favor. + max_episode_steps: Maximum length of an episode, can override the registered :class:`EnvSpec` ``max_episode_steps`` + with the value being passed to :class:`gymnasium.wrappers.TimeLimit`. + Using ``max_episode_steps=-1`` will not apply the wrapper to the environment. disable_env_checker: If to add :class:`gymnasium.wrappers.PassiveEnvChecker`, ``None`` will default to the :class:`EnvSpec` ``disable_env_checker`` value, otherwise this value will be used. kwargs: Additional arguments to pass to the environment constructor. @@ -725,6 +657,9 @@ def make( Raises: Error: If the ``id`` doesn't exist in the :attr:`registry` + + Changelogs: + v1.0.0 - `autoreset` and `apply_api_compatibility` were removed """ if isinstance(id, EnvSpec): env_spec = id @@ -790,14 +725,6 @@ def make( f"that is not in the possible render_modes ({render_modes})."
) - if apply_api_compatibility or ( - apply_api_compatibility is None and env_spec.apply_api_compatibility - ): - # If we use the compatibility layer, we treat the render mode explicitly and don't pass it to the env creator - render_mode = env_spec_kwargs.pop("render_mode", None) - else: - render_mode = None - try: env = env_creator(**env_spec_kwargs) except TypeError as e: @@ -823,9 +750,7 @@ nondeterministic=env_spec.nondeterministic, max_episode_steps=None, order_enforce=False, - autoreset=False, disable_env_checker=True, - apply_api_compatibility=False, kwargs=env_spec_kwargs, additional_wrappers=(), vector_entry_point=env_spec.vector_entry_point, @@ -845,15 +770,6 @@ f"The environment's wrapper spec {recreated_wrapper_spec} is different from the saved `EnvSpec` additional wrapper {env_spec_wrapper_spec}" ) - # Add step API wrapper - if apply_api_compatibility is True or ( - apply_api_compatibility is None and env_spec.apply_api_compatibility is True - ): - logger.warn( - "`gymnasium.make(..., apply_api_compatibility=True)` and `env_spec.apply_api_compatibility` is deprecated and will be removed in v1.0" - ) - env = gym.wrappers.EnvCompatibility(env, render_mode) - # Run the environment checker as the lowest level wrapper if disable_env_checker is False or ( disable_env_checker is None and env_spec.disable_env_checker is False @@ -865,18 +781,11 @@ env = gym.wrappers.OrderEnforcing(env) # Add the time limit wrapper - if max_episode_steps is not None: - env = gym.wrappers.TimeLimit(env, max_episode_steps) - elif env_spec.max_episode_steps is not None: - env = gym.wrappers.TimeLimit(env, env_spec.max_episode_steps) - - # Add the auto-reset wrapper - if autoreset is True or (autoreset is None and env_spec.autoreset is True): - env = gym.wrappers.AutoResetWrapper(env) - - logger.warn( - "`gymnasium.make(..., autoreset=True)` is deprecated and will be removed in v1.0" - ) + if max_episode_steps != -1: + if max_episode_steps is not None: + env = gym.wrappers.TimeLimit(env, max_episode_steps) + elif env_spec.max_episode_steps is not None: + env = gym.wrappers.TimeLimit(env, env_spec.max_episode_steps) for wrapper_spec in env_spec.additional_wrappers[num_prior_wrappers:]: if wrapper_spec.kwargs is None: @@ -898,25 +807,25 @@ def make( def make_vec( id: str | EnvSpec, num_envs: int = 1, - vectorization_mode: str = "async", + vectorization_mode: str | None = None, vector_kwargs: dict[str, Any] | None = None, wrappers: Sequence[Callable[[Env], Wrapper]] | None = None, **kwargs, -) -> gym.experimental.vector.VectorEnv: +) -> gym.vector.VectorEnv: """Create a vector environment according to the given ID. - Note: - This feature is experimental, and is likely to change in future releases. - - To find all available environments use `gymnasium.envs.registry.keys()` for all valid ids. + To find all available environments use :func:`gymnasium.pprint_registry` or ``gymnasium.registry.keys()`` for all valid ids. + We refer to the vector environment as the vectorizer, while the environment being vectorized is the base or vectorized environment (``vectorizer(vectorized env)``). Args: id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0' num_envs: Number of environments to create - vectorization_mode: How to vectorize the environment. Can be either "async", "sync" or "custom" - vector_kwargs: Additional arguments to pass to the vectorized environment constructor. - wrappers: A sequence of wrapper functions to apply to the environment.
Can only be used in "sync" or "async" mode. - **kwargs: Additional arguments to pass to the environment constructor. + vectorization_mode: The vectorization method used, defaults to ``None`` such that if a ``vector_entry_point`` exists, + it is used first, otherwise it defaults to ``sync``, using :class:`gymnasium.vector.SyncVectorEnv`. + Valid modes are ``"async"``, ``"sync"`` or ``"vector_entry_point"``. + vector_kwargs: Additional arguments to pass to the vectorizer environment constructor, i.e., ``SyncVectorEnv(..., **vector_kwargs)``. + wrappers: A sequence of wrapper functions to apply to the base environment. Can only be used in ``"sync"`` or ``"async"`` mode. + **kwargs: Additional arguments passed to the base environment constructor. Returns: An instance of the environment. @@ -926,87 +835,93 @@ """ if vector_kwargs is None: vector_kwargs = {} - if wrappers is None: wrappers = [] if isinstance(id, EnvSpec): - spec_ = id + id_env_spec = id + env_spec_kwargs = id_env_spec.kwargs.copy() + + num_envs = env_spec_kwargs.pop("num_envs", num_envs) + vectorization_mode = env_spec_kwargs.pop( + "vectorization_mode", vectorization_mode + ) + vector_kwargs = env_spec_kwargs.pop("vector_kwargs", vector_kwargs) + wrappers = env_spec_kwargs.pop("wrappers", wrappers) else: - spec_ = _find_spec(id) + id_env_spec = _find_spec(id) + env_spec_kwargs = id_env_spec.kwargs.copy() - _kwargs = spec_.kwargs.copy() - _kwargs.update(kwargs) + env_spec_kwargs.update(kwargs) - # Check if we have the necessary entry point - if vectorization_mode in ("sync", "async"): - if spec_.entry_point is None: - raise error.Error( - f"Cannot create vectorized environment for {id} because it doesn't have an entry point defined." - ) - entry_point = spec_.entry_point - elif vectorization_mode in ("custom",): - if spec_.vector_entry_point is None: - raise error.Error( - f"Cannot create vectorized environment for {id} because it doesn't have a vector entry point defined."
- ) - entry_point = spec_.vector_entry_point - else: - raise error.Error(f"Invalid vectorization mode: {vectorization_mode}") + # Update the vectorization_mode if None + if vectorization_mode is None: + if id_env_spec.vector_entry_point is not None: + vectorization_mode = "vector_entry_point" + else: + vectorization_mode = "sync" - if callable(entry_point): - env_creator = entry_point - else: - # Assume it's a string - env_creator = load_env_creator(entry_point) - - def _create_env(): - # Env creator for use with sync and async modes - _kwargs_copy = _kwargs.copy() - - render_mode = _kwargs.get("render_mode", None) - if render_mode is not None: - inner_render_mode = ( - render_mode[: -len("_list")] - if render_mode.endswith("_list") - else render_mode - ) - _kwargs_copy["render_mode"] = inner_render_mode - - _env = env_creator(**_kwargs_copy) - _env.spec = spec_ - if spec_.max_episode_steps is not None: - _env = gym.wrappers.TimeLimit(_env, spec_.max_episode_steps) - - if render_mode is not None and render_mode.endswith("_list"): - _env = gym.wrappers.RenderCollection(_env) + def create_single_env() -> Env: + single_env = make(id_env_spec.id, **env_spec_kwargs.copy()) for wrapper in wrappers: - _env = wrapper(_env) - return _env + single_env = wrapper(single_env) + return single_env if vectorization_mode == "sync": - env = gym.experimental.vector.SyncVectorEnv( - env_fns=[_create_env for _ in range(num_envs)], + if id_env_spec.entry_point is None: + raise error.Error( + f"Cannot create vectorized environment for {id_env_spec.id} because it doesn't have an entry point defined." + ) + + env = gym.vector.SyncVectorEnv( + env_fns=(create_single_env for _ in range(num_envs)), **vector_kwargs, ) elif vectorization_mode == "async": - env = gym.experimental.vector.AsyncVectorEnv( - env_fns=[_create_env for _ in range(num_envs)], + if id_env_spec.entry_point is None: + raise error.Error( + f"Cannot create vectorized environment for {id_env_spec.id} because it doesn't have an entry point defined." + ) + + env = gym.vector.AsyncVectorEnv( + env_fns=[create_single_env for _ in range(num_envs)], **vector_kwargs, ) - elif vectorization_mode == "custom": + elif vectorization_mode == "vector_entry_point": + entry_point = id_env_spec.vector_entry_point + if entry_point is None: + raise error.Error( + f"Cannot create vectorized environment for {id} because it doesn't have a vector entry point defined." + ) + elif callable(entry_point): + env_creator = entry_point + else: # Assume it's a string + env_creator = load_env_creator(entry_point) + if len(wrappers) > 0: - raise error.Error("Cannot use custom vectorization mode with wrappers.") - vector_kwargs["max_episode_steps"] = spec_.max_episode_steps + raise error.Error( + "Cannot use `vector_entry_point` vectorization mode with the wrappers argument." 
+ ) + if "max_episode_steps" not in vector_kwargs: + vector_kwargs["max_episode_steps"] = id_env_spec.max_episode_steps + env = env_creator(num_envs=num_envs, **vector_kwargs) else: raise error.Error(f"Invalid vectorization mode: {vectorization_mode}") # Copies the environment creation specification and kwargs to add to the environment specification details - spec_ = copy.deepcopy(spec_) - spec_.kwargs = _kwargs - env.unwrapped.spec = spec_ + copied_id_spec = copy.deepcopy(id_env_spec) + copied_id_spec.kwargs = env_spec_kwargs + if num_envs != 1: + copied_id_spec.kwargs["num_envs"] = num_envs + if vectorization_mode != "async": + copied_id_spec.kwargs["vectorization_mode"] = vectorization_mode + if vector_kwargs is not None: + copied_id_spec.kwargs["vector_kwargs"] = vector_kwargs + if wrappers is not None: + copied_id_spec.kwargs["wrappers"] = wrappers + env.unwrapped.spec = copied_id_spec return env diff --git a/gymnasium/envs/tabular/blackjack.py b/gymnasium/envs/tabular/blackjack.py index 4ee81f431..bfdf341dc 100644 --- a/gymnasium/envs/tabular/blackjack.py +++ b/gymnasium/envs/tabular/blackjack.py @@ -12,9 +12,9 @@ from jax import random from jax.random import PRNGKey from gymnasium import spaces +from gymnasium.envs.functional_jax_env import FunctionalJaxEnv from gymnasium.error import DependencyNotInstalled -from gymnasium.experimental.functional import ActType, FuncEnv, StateType -from gymnasium.experimental.functional_jax_env import FunctionalJaxEnv +from gymnasium.functional import ActType, FuncEnv, StateType from gymnasium.utils import EzPickle, seeding from gymnasium.wrappers import HumanRendering diff --git a/gymnasium/envs/tabular/cliffwalking.py b/gymnasium/envs/tabular/cliffwalking.py index 2b3da1ffc..d887e75ca 100644 --- a/gymnasium/envs/tabular/cliffwalking.py +++ b/gymnasium/envs/tabular/cliffwalking.py @@ -12,9 +12,9 @@ import numpy as np from jax.random import PRNGKey from gymnasium import spaces +from gymnasium.envs.functional_jax_env import FunctionalJaxEnv from gymnasium.error import DependencyNotInstalled -from gymnasium.experimental.functional import ActType, FuncEnv, StateType -from gymnasium.experimental.functional_jax_env import FunctionalJaxEnv +from gymnasium.functional import ActType, FuncEnv, StateType from gymnasium.utils import EzPickle from gymnasium.wrappers import HumanRendering diff --git a/gymnasium/experimental/__init__.py b/gymnasium/experimental/__init__.py deleted file mode 100644 index 6157cfcd2..000000000 --- a/gymnasium/experimental/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Root __init__ of the gym experimental wrappers.""" - - -from gymnasium.experimental import functional, vector, wrappers - - -# from gymnasium.experimental.functional import FuncEnv -# from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv -# from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv -# from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper - - -__all__ = [ - # Functional - # "FuncEnv", - "functional", - # Vector - # "VectorEnv", - # "VectorWrapper", - # "SyncVectorEnv", - # "AsyncVectorEnv", - # wrappers - "wrappers", - "vector", -] diff --git a/gymnasium/experimental/vector/__init__.py b/gymnasium/experimental/vector/__init__.py deleted file mode 100644 index 0f5839a03..000000000 --- a/gymnasium/experimental/vector/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Experimental vector env API.""" -from gymnasium.experimental.vector import utils -from gymnasium.experimental.vector.async_vector_env 
import AsyncVectorEnv -from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv -from gymnasium.experimental.vector.vector_env import ( - VectorActionWrapper, - VectorEnv, - VectorObservationWrapper, - VectorRewardWrapper, - VectorWrapper, -) - - -__all__ = [ - "VectorEnv", - "VectorWrapper", - "VectorObservationWrapper", - "VectorActionWrapper", - "VectorRewardWrapper", - "SyncVectorEnv", - "AsyncVectorEnv", - "utils", -] diff --git a/gymnasium/experimental/vector/async_vector_env.py b/gymnasium/experimental/vector/async_vector_env.py deleted file mode 100644 index 1d5e03620..000000000 --- a/gymnasium/experimental/vector/async_vector_env.py +++ /dev/null @@ -1,685 +0,0 @@ -"""An async vector environment.""" -from __future__ import annotations - -import multiprocessing -import sys -import time -from copy import deepcopy -from enum import Enum -from multiprocessing import Queue -from multiprocessing.connection import Connection -from typing import Any, Callable, Sequence - -import numpy as np - -from gymnasium import logger -from gymnasium.core import Env, ObsType -from gymnasium.error import ( - AlreadyPendingCallError, - ClosedEnvironmentError, - CustomSpaceError, - NoAsyncCallError, -) -from gymnasium.experimental.vector.utils import ( - CloudpickleWrapper, - batch_space, - clear_mpi_env_vars, - concatenate, - create_empty_array, - create_shared_memory, - iterate, - read_from_shared_memory, - write_to_shared_memory, -) -from gymnasium.experimental.vector.vector_env import VectorEnv - - -__all__ = ["AsyncVectorEnv"] - - -class AsyncState(Enum): - DEFAULT = "default" - WAITING_RESET = "reset" - WAITING_STEP = "step" - WAITING_CALL = "call" - - -class AsyncVectorEnv(VectorEnv): - """Vectorized environment that runs multiple environments in parallel. - - It uses ``multiprocessing`` processes, and pipes for communication. - - Example: - >>> import gymnasium as gym - >>> env = gym.vector.AsyncVectorEnv([ - ... lambda: gym.make("Pendulum-v1", g=9.81), - ... lambda: gym.make("Pendulum-v1", g=1.62) - ... ]) - >>> env.reset(seed=42) - (array([[-0.14995256, 0.9886932 , -0.12224312], - [ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {}) - """ - - def __init__( - self, - env_fns: Sequence[Callable[[], Env]], - shared_memory: bool = True, - copy: bool = True, - context: str | None = None, - daemon: bool = True, - worker: callable | None = None, - ): - """Vectorized environment that runs multiple environments in parallel. - - Args: - env_fns: Functions that create the environments. - shared_memory: If ``True``, then the observations from the worker processes are communicated back through - shared variables. This can improve the efficiency if the observations are large (e.g. images). - copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods - return a copy of the observations. - context: Context for `multiprocessing`_. If ``None``, then the default context is used. - daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if - the head process quits. However, ``daemon=True`` prevents subprocesses to spawn children, - so for some environments you may want to have it set to ``False``. - worker: If set, then use that worker in a subprocess instead of a default one. - Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled. - - Warnings: - worker is an advanced mode option. 
It provides a high degree of flexibility and a high chance - to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start - from the code for ``_worker`` (or ``_worker_shared_memory``) method, and add changes. - - Raises: - RuntimeError: If the observation space of some sub-environment does not match observation_space - (or, by default, the observation space of the first sub-environment). - ValueError: If observation_space is a custom space (i.e. not a default space in Gym, - such as gymnasium.spaces.Box, gymnasium.spaces.Discrete, or gymnasium.spaces.Dict) and shared_memory is True. - """ - super().__init__() - - ctx = multiprocessing.get_context(context) - self.env_fns = env_fns - self.num_envs = len(env_fns) - self.shared_memory = shared_memory - self.copy = copy - - # This would be nice to get rid of, but without it there's a deadlock between shared memory and pipes - dummy_env = env_fns[0]() - self.metadata = dummy_env.metadata - - self.single_observation_space = dummy_env.observation_space - self.single_action_space = dummy_env.action_space - - self.observation_space = batch_space( - self.single_observation_space, self.num_envs - ) - self.action_space = batch_space(self.single_action_space, self.num_envs) - - dummy_env.close() - del dummy_env - - if self.shared_memory: - try: - _obs_buffer = create_shared_memory( - self.single_observation_space, n=self.num_envs, ctx=ctx - ) - self.observations = read_from_shared_memory( - self.single_observation_space, _obs_buffer, n=self.num_envs - ) - except CustomSpaceError as e: - raise ValueError( - "Using `shared_memory=True` in `AsyncVectorEnv` " - "is incompatible with non-standard Gymnasium observation spaces " - "(i.e. custom spaces inheriting from `gymnasium.Space`), and is " - "only compatible with default Gymnasium spaces (e.g. `Box`, " - "`Tuple`, `Dict`) for batching. Set `shared_memory=False` " - "if you use custom observation spaces." - ) from e - else: - _obs_buffer = None - self.observations = create_empty_array( - self.single_observation_space, n=self.num_envs, fn=np.zeros - ) - - self.parent_pipes, self.processes = [], [] - self.error_queue = ctx.Queue() - target = worker or _worker - with clear_mpi_env_vars(): - for idx, env_fn in enumerate(self.env_fns): - parent_pipe, child_pipe = ctx.Pipe() - process = ctx.Process( - target=target, - name=f"Worker<{type(self).__name__}>-{idx}", - args=( - idx, - CloudpickleWrapper(env_fn), - child_pipe, - parent_pipe, - _obs_buffer, - self.error_queue, - ), - ) - - self.parent_pipes.append(parent_pipe) - self.processes.append(process) - - process.daemon = daemon - process.start() - child_pipe.close() - - self._state = AsyncState.DEFAULT - self._check_spaces() - - def reset_async( - self, - seed: int | list[int] | None = None, - options: dict | None = None, - ): - """Send calls to the :obj:`reset` methods of the sub-environments. - - To get the results of these calls, you may invoke :meth:`reset_wait`. - - Args: - seed: List of seeds for each environment - options: The reset option - - Raises: - ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). - AlreadyPendingCallError: If the environment is already waiting for a pending call to another - method (e.g. :meth:`step_async`). This can be caused by two consecutive - calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in between. 
- """ - self._assert_is_running() - - if seed is None: - seed = [None for _ in range(self.num_envs)] - if isinstance(seed, int): - seed = [seed + i for i in range(self.num_envs)] - assert len(seed) == self.num_envs - - if self._state != AsyncState.DEFAULT: - raise AlreadyPendingCallError( - f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete", - str(self._state.value), - ) - - for pipe, single_seed in zip(self.parent_pipes, seed): - single_kwargs = {} - if single_seed is not None: - single_kwargs["seed"] = single_seed - if options is not None: - single_kwargs["options"] = options - - pipe.send(("reset", single_kwargs)) - self._state = AsyncState.WAITING_RESET - - def reset_wait( - self, - timeout: int | float | None = None, - ) -> tuple[ObsType, list[dict]]: - """Waits for the calls triggered by :meth:`reset_async` to finish and returns the results. - - Args: - timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out. - - Returns: - A tuple of batched observations and list of dictionaries - - Raises: - ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). - NoAsyncCallError: If :meth:`reset_wait` was called without any prior call to :meth:`reset_async`. - TimeoutError: If :meth:`reset_wait` timed out. - """ - self._assert_is_running() - if self._state != AsyncState.WAITING_RESET: - raise NoAsyncCallError( - "Calling `reset_wait` without any prior " "call to `reset_async`.", - AsyncState.WAITING_RESET.value, - ) - - if not self._poll(timeout): - self._state = AsyncState.DEFAULT - raise multiprocessing.TimeoutError( - f"The call to `reset_wait` has timed out after {timeout} second(s)." - ) - - results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) - self._raise_if_errors(successes) - self._state = AsyncState.DEFAULT - - infos = {} - results, info_data = zip(*results) - for i, info in enumerate(info_data): - infos = self._add_info(infos, info, i) - - if not self.shared_memory: - self.observations = concatenate( - self.single_observation_space, results, self.observations - ) - - return (deepcopy(self.observations) if self.copy else self.observations), infos - - def reset( - self, - *, - seed: int | list[int] | None = None, - options: dict | None = None, - ): - """Reset all parallel environments and return a batch of initial observations and info. - - Args: - seed: The environment reset seeds - options: If to return the options - - Returns: - A batch of observations and info from the vectorized environment. - """ - self.reset_async(seed=seed, options=options) - return self.reset_wait() - - def step_async(self, actions: np.ndarray): - """Send the calls to :obj:`step` to each sub-environment. - - Args: - actions: Batch of actions. element of :attr:`~VectorEnv.action_space` - - Raises: - ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). - AlreadyPendingCallError: If the environment is already waiting for a pending call to another - method (e.g. :meth:`reset_async`). This can be caused by two consecutive - calls to :meth:`step_async`, with no call to :meth:`step_wait` in - between. 
- """ - self._assert_is_running() - if self._state != AsyncState.DEFAULT: - raise AlreadyPendingCallError( - f"Calling `step_async` while waiting for a pending call to `{self._state.value}` to complete.", - str(self._state.value), - ) - - actions = iterate(self.action_space, actions) - for pipe, action in zip(self.parent_pipes, actions): - pipe.send(("step", action)) - self._state = AsyncState.WAITING_STEP - - def step_wait( - self, timeout: int | float | None = None - ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]: - """Wait for the calls to :obj:`step` in each sub-environment to finish. - - Args: - timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out. - - Returns: - The batched environment step information, (obs, reward, terminated, truncated, info) - - Raises: - ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). - NoAsyncCallError: If :meth:`step_wait` was called without any prior call to :meth:`step_async`. - TimeoutError: If :meth:`step_wait` timed out. - """ - self._assert_is_running() - if self._state != AsyncState.WAITING_STEP: - raise NoAsyncCallError( - "Calling `step_wait` without any prior call " "to `step_async`.", - AsyncState.WAITING_STEP.value, - ) - - if not self._poll(timeout): - self._state = AsyncState.DEFAULT - raise multiprocessing.TimeoutError( - f"The call to `step_wait` has timed out after {timeout} second(s)." - ) - - observations_list, rewards, terminateds, truncateds, infos = [], [], [], [], {} - successes = [] - for i, pipe in enumerate(self.parent_pipes): - result, success = pipe.recv() - obs, rew, terminated, truncated, info = result - - successes.append(success) - if success: - observations_list.append(obs) - rewards.append(rew) - terminateds.append(terminated) - truncateds.append(truncated) - infos = self._add_info(infos, info, i) - - self._raise_if_errors(successes) - self._state = AsyncState.DEFAULT - - if not self.shared_memory: - self.observations = concatenate( - self.single_observation_space, - observations_list, - self.observations, - ) - - return ( - deepcopy(self.observations) if self.copy else self.observations, - np.array(rewards), - np.array(terminateds, dtype=np.bool_), - np.array(truncateds, dtype=np.bool_), - infos, - ) - - def step(self, actions): - """Take an action for each parallel environment. - - Args: - actions: element of :attr:`action_space` Batch of actions. - - Returns: - Batch of (observations, rewards, terminations, truncations, infos) - """ - self.step_async(actions) - return self.step_wait() - - def call_async(self, name: str, *args, **kwargs): - """Calls the method with name asynchronously and apply args and kwargs to the method. - - Args: - name: Name of the method or property to call. - *args: Arguments to apply to the method call. - **kwargs: Keyword arguments to apply to the method call. - - Raises: - ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). 
-            AlreadyPendingCallError: Calling `call_async` while waiting for a pending call to complete
-        """
-        self._assert_is_running()
-        if self._state != AsyncState.DEFAULT:
-            raise AlreadyPendingCallError(
-                "Calling `call_async` while waiting "
-                f"for a pending call to `{self._state.value}` to complete.",
-                str(self._state.value),
-            )
-
-        for pipe in self.parent_pipes:
-            pipe.send(("_call", (name, args, kwargs)))
-        self._state = AsyncState.WAITING_CALL
-
-    def call_wait(self, timeout: int | float | None = None) -> list:
-        """Waits for the results of the calls triggered by :meth:`call_async` on each parent pipe.
-
-        Args:
-            timeout: Number of seconds before the call to `call_wait` times out.
-                If `None` (default), the call to `call_wait` never times out.
-
-        Returns:
-            List of the results of the individual calls to the method or property for each environment.
-
-        Raises:
-            NoAsyncCallError: Calling `call_wait` without any prior call to `call_async`.
-            TimeoutError: The call to `call_wait` has timed out after timeout second(s).
-        """
-        self._assert_is_running()
-        if self._state != AsyncState.WAITING_CALL:
-            raise NoAsyncCallError(
-                "Calling `call_wait` without any prior call to `call_async`.",
-                AsyncState.WAITING_CALL.value,
-            )
-
-        if not self._poll(timeout):
-            self._state = AsyncState.DEFAULT
-            raise multiprocessing.TimeoutError(
-                f"The call to `call_wait` has timed out after {timeout} second(s)."
-            )
-
-        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
-        self._raise_if_errors(successes)
-        self._state = AsyncState.DEFAULT
-
-        return results
-
-    def call(self, name: str, *args, **kwargs) -> list[Any]:
-        """Call a method, or get a property, from each parallel environment.
-
-        Args:
-            name (str): Name of the method or property to call.
-            *args: Arguments to apply to the method call.
-            **kwargs: Keyword arguments to apply to the method call.
-
-        Returns:
-            List of the results of the individual calls to the method or property for each environment.
-        """
-        self.call_async(name, *args, **kwargs)
-        return self.call_wait()
-
-    def get_attr(self, name: str):
-        """Get a property from each parallel environment.
-
-        Args:
-            name (str): Name of the property to get from each individual environment.
-
-        Returns:
-            The property with name ``name`` from each environment
-        """
-        return self.call(name)
-
-    def set_attr(self, name: str, values: list[Any] | tuple[Any] | object):
-        """Sets an attribute of the sub-environments.
-
-        Args:
-            name: Name of the property to be set in each individual environment.
-            values: Values of the property to be set to. If ``values`` is a list or
-                tuple, then it corresponds to the values for each individual
-                environment, otherwise a single value is set for all environments.
-
-        Raises:
-            ValueError: Values must be a list or tuple with length equal to the number of environments.
-            AlreadyPendingCallError: Calling `set_attr` while waiting for a pending call to complete.
-        """
-        self._assert_is_running()
-        if not isinstance(values, (list, tuple)):
-            values = [values for _ in range(self.num_envs)]
-        if len(values) != self.num_envs:
-            raise ValueError(
-                "Values must be a list or tuple with length equal to the "
-                f"number of environments. Got `{len(values)}` values for "
-                f"{self.num_envs} environments."
- ) - - if self._state != AsyncState.DEFAULT: - raise AlreadyPendingCallError( - "Calling `set_attr` while waiting " - f"for a pending call to `{self._state.value}` to complete.", - str(self._state.value), - ) - - for pipe, value in zip(self.parent_pipes, values): - pipe.send(("_setattr", (name, value))) - _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) - self._raise_if_errors(successes) - - def close_extras(self, timeout: int | float | None = None, terminate: bool = False): - """Close the environments & clean up the extra resources (processes and pipes). - - Args: - timeout: Number of seconds before the call to :meth:`close` times out. If ``None``, - the call to :meth:`close` never times out. If the call to :meth:`close` - times out, then all processes are terminated. - terminate: If ``True``, then the :meth:`close` operation is forced and all processes are terminated. - - Raises: - TimeoutError: If :meth:`close` timed out. - """ - timeout = 0 if terminate else timeout - try: - if self._state != AsyncState.DEFAULT: - logger.warn( - f"Calling `close` while waiting for a pending call to `{self._state.value}` to complete." - ) - function = getattr(self, f"{self._state.value}_wait") - function(timeout) - except multiprocessing.TimeoutError: - terminate = True - - if terminate: - for process in self.processes: - if process.is_alive(): - process.terminate() - else: - for pipe in self.parent_pipes: - if (pipe is not None) and (not pipe.closed): - pipe.send(("close", None)) - for pipe in self.parent_pipes: - if (pipe is not None) and (not pipe.closed): - pipe.recv() - - for pipe in self.parent_pipes: - if pipe is not None: - pipe.close() - for process in self.processes: - process.join() - - def _poll(self, timeout=None): - self._assert_is_running() - if timeout is None: - return True - end_time = time.perf_counter() + timeout - delta = None - for pipe in self.parent_pipes: - delta = max(end_time - time.perf_counter(), 0) - if pipe is None: - return False - if pipe.closed or (not pipe.poll(delta)): - return False - return True - - def _check_spaces(self): - self._assert_is_running() - spaces = (self.single_observation_space, self.single_action_space) - for pipe in self.parent_pipes: - pipe.send(("_check_spaces", spaces)) - results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) - self._raise_if_errors(successes) - same_observation_spaces, same_action_spaces = zip(*results) - if not all(same_observation_spaces): - raise RuntimeError( - f"Some environments have an observation space different from `{self.single_observation_space}`. " - "In order to batch observations, the observation spaces from all environments must be equal." - ) - if not all(same_action_spaces): - raise RuntimeError( - f"Some environments have an action space different from `{self.single_action_space}`. " - "In order to batch actions, the action spaces from all environments must be equal." - ) - - def _assert_is_running(self): - if self.closed: - raise ClosedEnvironmentError( - f"Trying to operate on `{type(self).__name__}`, after a call to `close()`." 
- ) - - def _raise_if_errors(self, successes: list[bool]): - if all(successes): - return - - num_errors = self.num_envs - sum(successes) - assert num_errors > 0 - for i in range(num_errors): - index, exctype, value = self.error_queue.get() - logger.error( - f"Received the following error from Worker-{index}: {exctype.__name__}: {value}" - ) - logger.error(f"Shutting down Worker-{index}.") - self.parent_pipes[index].close() - self.parent_pipes[index] = None - - if i == num_errors - 1: - logger.error("Raising the last exception back to the main process.") - raise exctype(value) - - def __del__(self): - """On deleting the object, checks that the vector environment is closed.""" - if not getattr(self, "closed", True) and hasattr(self, "_state"): - self.close(terminate=True) - - -def _worker( - index: int, - env_fn: callable, - pipe: Connection, - parent_pipe: Connection, - shared_memory: bool, - error_queue: Queue, -): - env = env_fn() - observation_space = env.observation_space - action_space = env.action_space - parent_pipe.close() - try: - while True: - command, data = pipe.recv() - - if command == "reset": - observation, info = env.reset(**data) - if shared_memory: - write_to_shared_memory( - observation_space, index, observation, shared_memory - ) - observation = None - pipe.send(((observation, info), True)) - - elif command == "step": - ( - observation, - reward, - terminated, - truncated, - info, - ) = env.step(data) - if terminated or truncated: - old_observation, old_info = observation, info - observation, info = env.reset() - info["final_observation"] = old_observation - info["final_info"] = old_info - - if shared_memory: - write_to_shared_memory( - observation_space, index, observation, shared_memory - ) - observation = None - - pipe.send(((observation, reward, terminated, truncated, info), True)) - elif command == "seed": - env.seed(data) - pipe.send((None, True)) - elif command == "close": - pipe.send((None, True)) - break - elif command == "_call": - name, args, kwargs = data - if name in ["reset", "step", "seed", "close"]: - raise ValueError( - f"Trying to call function `{name}` with " - f"`_call`. Use `{name}` directly instead." - ) - function = getattr(env, name) - if callable(function): - pipe.send((function(*args, **kwargs), True)) - else: - pipe.send((function, True)) - elif command == "_setattr": - name, value = data - setattr(env, name, value) - pipe.send((None, True)) - elif command == "_check_spaces": - pipe.send( - ( - (data[0] == observation_space, data[1] == action_space), - True, - ) - ) - else: - raise RuntimeError( - f"Received unknown command `{command}`. Must " - "be one of {`reset`, `step`, `seed`, `close`, `_call`, " - "`_setattr`, `_check_spaces`}." 
-                )
-    except (KeyboardInterrupt, Exception):
-        error_queue.put((index,) + sys.exc_info()[:2])
-        pipe.send((None, False))
-    finally:
-        env.close()
diff --git a/gymnasium/experimental/vector/sync_vector_env.py b/gymnasium/experimental/vector/sync_vector_env.py
deleted file mode 100644
index 1021fd9fb..000000000
--- a/gymnasium/experimental/vector/sync_vector_env.py
+++ /dev/null
@@ -1,229 +0,0 @@
-"""A synchronous vector environment."""
-from __future__ import annotations
-
-from copy import deepcopy
-from typing import Any, Callable, Iterable
-
-import numpy as np
-
-from gymnasium import Env
-from gymnasium.experimental.vector.utils import (
-    batch_space,
-    concatenate,
-    create_empty_array,
-    iterate,
-)
-from gymnasium.experimental.vector.vector_env import VectorEnv
-
-
-__all__ = ["SyncVectorEnv"]
-
-
-class SyncVectorEnv(VectorEnv):
-    """Vectorized environment that serially runs multiple environments.
-
-    Example:
-        >>> import gymnasium as gym
-        >>> env = gym.vector.SyncVectorEnv([
-        ...     lambda: gym.make("Pendulum-v1", g=9.81),
-        ...     lambda: gym.make("Pendulum-v1", g=1.62)
-        ... ])
-        >>> env.reset(seed=42)
-        (array([[-0.14995256,  0.9886932 , -0.12224312],
-               [ 0.5760367 ,  0.8174238 , -0.91244936]], dtype=float32), {})
-    """
-
-    def __init__(
-        self,
-        env_fns: Iterable[Callable[[], Env]],
-        copy: bool = True,
-    ):
-        """Vectorized environment that serially runs multiple environments.
-
-        Args:
-            env_fns: iterable of callable functions that create the environments.
-            copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.
-
-        Raises:
-            RuntimeError: If the observation space of some sub-environment does not match observation_space
-                (or, by default, the observation space of the first sub-environment).
-        """
-        super().__init__()
-        self.env_fns = env_fns
-        self.envs = [env_fn() for env_fn in env_fns]
-        self.num_envs = len(self.envs)
-        self.copy = copy
-        self.metadata = self.envs[0].metadata
-
-        self.spec = self.envs[0].spec
-
-        self.single_observation_space = self.envs[0].observation_space
-        self.single_action_space = self.envs[0].action_space
-
-        self.observation_space = batch_space(
-            self.single_observation_space, self.num_envs
-        )
-        self.action_space = batch_space(self.single_action_space, self.num_envs)
-
-        self._check_spaces()
-        self.observations = create_empty_array(
-            self.single_observation_space, n=self.num_envs, fn=np.zeros
-        )
-        self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
-        self._terminateds = np.zeros((self.num_envs,), dtype=np.bool_)
-        self._truncateds = np.zeros((self.num_envs,), dtype=np.bool_)
-
-    def reset(
-        self,
-        seed: int | list[int] | None = None,
-        options: dict | None = None,
-    ):
-        """Resets each of the sub-environments and concatenates the results.
-
-        Args:
-            seed: The reset environment seed
-            options: Option information for the environment reset
-
-        Returns:
-            The reset observation of the environment and reset information
-        """
-        if seed is None:
-            seed = [None for _ in range(self.num_envs)]
-        if isinstance(seed, int):
-            seed = [seed + i for i in range(self.num_envs)]
-        assert len(seed) == self.num_envs
-
-        self._terminateds[:] = False
-        self._truncateds[:] = False
-        observations = []
-        infos = {}
-        for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
-            kwargs = {}
-            if single_seed is not None:
-                kwargs["seed"] = single_seed
-            if options is not None:
-                kwargs["options"] = options
-
-            observation, info = env.reset(**kwargs)
-            observations.append(observation)
-            infos = self._add_info(infos, info, i)
-
-        self.observations = concatenate(
-            self.single_observation_space, observations, self.observations
-        )
-        return (deepcopy(self.observations) if self.copy else self.observations), infos
-
-    def step(self, actions):
-        """Steps through each of the environments returning the batched results.
-
-        Args:
-            actions: The batched actions, an element of :attr:`action_space`
-
-        Returns:
-            The batched environment step results
-        """
-        actions = iterate(self.action_space, actions)
-
-        observations, infos = [], {}
-        for i, (env, action) in enumerate(zip(self.envs, actions)):
-            (
-                observation,
-                self._rewards[i],
-                self._terminateds[i],
-                self._truncateds[i],
-                info,
-            ) = env.step(action)
-
-            if self._terminateds[i] or self._truncateds[i]:
-                old_observation, old_info = observation, info
-                observation, info = env.reset()
-                info["final_observation"] = old_observation
-                info["final_info"] = old_info
-            observations.append(observation)
-            infos = self._add_info(infos, info, i)
-        self.observations = concatenate(
-            self.single_observation_space, observations, self.observations
-        )
-
-        return (
-            deepcopy(self.observations) if self.copy else self.observations,
-            np.copy(self._rewards),
-            np.copy(self._terminateds),
-            np.copy(self._truncateds),
-            infos,
-        )
-
-    def call(self, name, *args, **kwargs) -> tuple:
-        """Calls the method with name and applies args and kwargs.
-
-        Args:
-            name: The method name
-            *args: The method args
-            **kwargs: The method kwargs
-
-        Returns:
-            Tuple of results
-        """
-        results = []
-        for env in self.envs:
-            function = getattr(env, name)
-            if callable(function):
-                results.append(function(*args, **kwargs))
-            else:
-                results.append(function)
-
-        return tuple(results)
-
-    def get_attr(self, name: str):
-        """Get a property from each parallel environment.
-
-        Args:
-            name (str): Name of the property to get from each individual environment.
-
-        Returns:
-            The property with name ``name`` from each environment
-        """
-        return self.call(name)
-
-    def set_attr(self, name: str, values: list | tuple | Any):
-        """Sets an attribute of the sub-environments.
-
-        Args:
-            name: The property name to change
-            values: Values of the property to be set to. If ``values`` is a list or
-                tuple, then it corresponds to the values for each individual
-                environment, otherwise, a single value is set for all environments.
-
-        Raises:
-            ValueError: Values must be a list or tuple with length equal to the number of environments.
-        """
-        if not isinstance(values, (list, tuple)):
-            values = [values for _ in range(self.num_envs)]
-        if len(values) != self.num_envs:
-            raise ValueError(
-                "Values must be a list or tuple with length equal to the "
-                f"number of environments. Got `{len(values)}` values for "
-                f"{self.num_envs} environments."
-            )
-
-        for env, value in zip(self.envs, values):
-            setattr(env, name, value)
-
-    def close_extras(self, **kwargs):
-        """Close the environments."""
-        for env in self.envs:
-            env.close()
-
-    def _check_spaces(self) -> bool:
-        for env in self.envs:
-            if not (env.observation_space == self.single_observation_space):
-                raise RuntimeError(
-                    "Some environments have an observation space different from "
-                    f"`{self.single_observation_space}`. In order to batch observations, "
-                    "the observation spaces from all environments must be equal."
-                )
-
-            if not (env.action_space == self.single_action_space):
-                raise RuntimeError(
-                    "Some environments have an action space different from "
-                    f"`{self.single_action_space}`. In order to batch actions, the "
-                    "action spaces from all environments must be equal."
-                )
-
-        return True
diff --git a/gymnasium/experimental/vector/utils/__init__.py b/gymnasium/experimental/vector/utils/__init__.py
deleted file mode 100644
index 4fae88a76..000000000
--- a/gymnasium/experimental/vector/utils/__init__.py
+++ /dev/null
@@ -1,30 +0,0 @@
-"""Module for gymnasium experimental vector utility functions."""
-
-from gymnasium.experimental.vector.utils.misc import (
-    CloudpickleWrapper,
-    clear_mpi_env_vars,
-)
-from gymnasium.experimental.vector.utils.shared_memory import (
-    create_shared_memory,
-    read_from_shared_memory,
-    write_to_shared_memory,
-)
-from gymnasium.experimental.vector.utils.space_utils import (
-    batch_space,
-    concatenate,
-    create_empty_array,
-    iterate,
-)
-
-
-__all__ = [
-    "batch_space",
-    "iterate",
-    "concatenate",
-    "create_empty_array",
-    "create_shared_memory",
-    "read_from_shared_memory",
-    "write_to_shared_memory",
-    "CloudpickleWrapper",
-    "clear_mpi_env_vars",
-]
diff --git a/gymnasium/experimental/vector/utils/misc.py b/gymnasium/experimental/vector/utils/misc.py
deleted file mode 100644
index c8cd1f368..000000000
--- a/gymnasium/experimental/vector/utils/misc.py
+++ /dev/null
@@ -1,61 +0,0 @@
-"""Miscellaneous utilities."""
-from __future__ import annotations
-
-import contextlib
-import os
-from collections.abc import Callable
-
-from gymnasium.core import Env
-
-
-__all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"]
-
-
-class CloudpickleWrapper:
-    """Wrapper that uses cloudpickle to pickle and unpickle the result."""
-
-    def __init__(self, fn: Callable[[], Env]):
-        """Cloudpickle wrapper for a function."""
-        self.fn = fn
-
-    def __getstate__(self):
-        """Get the state using `cloudpickle.dumps(self.fn)`."""
-        import cloudpickle
-
-        return cloudpickle.dumps(self.fn)
-
-    def __setstate__(self, ob):
-        """Sets the state by unpickling the cloudpickled bytes ``ob``."""
-        import pickle
-
-        self.fn = pickle.loads(ob)
-
-    def __call__(self):
-        """Calls the function `self.fn` with no arguments."""
-        return self.fn()
-
-
-@contextlib.contextmanager
-def clear_mpi_env_vars():
-    """Clears MPI-related environment variables.
-
-    `from mpi4py import MPI` will call `MPI_Init` by default.
-    If the child process has MPI environment variables, MPI will think that the child process
-    is an MPI process just like the parent and do bad things such as hang.
-
-    This context manager is a hacky way to clear those environment variables
-    temporarily, such as when we are starting multiprocessing Processes.
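-
-    Example:
-        A minimal usage sketch (``make_env`` is a hypothetical callable used as
-        the child-process target; the removed ``OMPI_``/``PMI_`` variables are
-        restored once the block exits):
-
-        >>> import multiprocessing as mp
-        >>> ctx = mp.get_context("spawn")
-        >>> with clear_mpi_env_vars():  # doctest: +SKIP
-        ...     process = ctx.Process(target=make_env)
-        ...     process.start()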
-
-    Yields:
-        Control to the wrapped block, with the `OMPI_`/`PMI_` environment variables temporarily removed
-    """
-    removed_environment = {}
-    for k, v in list(os.environ.items()):
-        for prefix in ["OMPI_", "PMI_"]:
-            if k.startswith(prefix):
-                removed_environment[k] = v
-                del os.environ[k]
-    try:
-        yield
-    finally:
-        os.environ.update(removed_environment)
diff --git a/gymnasium/experimental/vector/utils/shared_memory.py b/gymnasium/experimental/vector/utils/shared_memory.py
deleted file mode 100644
index 6f4d9c8a1..000000000
--- a/gymnasium/experimental/vector/utils/shared_memory.py
+++ /dev/null
@@ -1,255 +0,0 @@
-"""Utility functions for vector environments to share memory between processes."""
-from __future__ import annotations
-
-import multiprocessing as mp
-from collections import OrderedDict
-from ctypes import c_bool
-from functools import singledispatch
-from typing import Any
-
-import numpy as np
-
-from gymnasium.error import CustomSpaceError
-from gymnasium.spaces import (
-    Box,
-    Dict,
-    Discrete,
-    Graph,
-    MultiBinary,
-    MultiDiscrete,
-    Sequence,
-    Space,
-    Text,
-    Tuple,
-    flatten,
-)
-
-
-__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]
-
-
-@singledispatch
-def create_shared_memory(
-    space: Space[Any], n: int = 1, ctx=mp
-) -> dict[str, Any] | tuple[Any, ...] | mp.Array:
-    """Create a shared memory object, to be shared across processes.
-
-    This eventually contains the observations from the vectorized environment.
-
-    Args:
-        space: Observation space of a single environment in the vectorized environment.
-        n: Number of environments in the vectorized environment (i.e. the number of processes).
-        ctx: The multiprocessing module or context
-
-    Returns:
-        shared_memory for the shared object across processes.
-
-    Raises:
-        CustomSpaceError: If the space is a custom :class:`gymnasium.Space` instance with no registered `create_shared_memory` implementation
-        TypeError: If the space is not a :class:`gymnasium.Space` instance
-    """
-    if isinstance(space, Space):
-        raise CustomSpaceError(
-            f"Space of type `{type(space)}` doesn't have a registered `create_shared_memory` function. Register `{type(space)}` for `create_shared_memory` to support it."
-        )
-    else:
-        raise TypeError(
-            f"The space provided to `create_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
-        )
-
-
-@create_shared_memory.register(Box)
-@create_shared_memory.register(Discrete)
-@create_shared_memory.register(MultiDiscrete)
-@create_shared_memory.register(MultiBinary)
-def _create_base_shared_memory(
-    space: Box | Discrete | MultiDiscrete | MultiBinary, n: int = 1, ctx=mp
-):
-    assert space.dtype is not None
-    dtype = space.dtype.char
-    if dtype == "?":
-        dtype = c_bool
-    return ctx.Array(dtype, n * int(np.prod(space.shape)))
-
-
-@create_shared_memory.register(Tuple)
-def _create_tuple_shared_memory(space: Tuple, n: int = 1, ctx=mp):
-    return tuple(
-        create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
-    )
-
-
-@create_shared_memory.register(Dict)
-def _create_dict_shared_memory(space: Dict, n: int = 1, ctx=mp):
-    return OrderedDict(
-        [
-            (key, create_shared_memory(subspace, n=n, ctx=ctx))
-            for (key, subspace) in space.spaces.items()
-        ]
-    )
-
-
-@create_shared_memory.register(Text)
-def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
-    return ctx.Array(np.dtype(np.int32).char, n * space.max_length)
-
-
-@create_shared_memory.register(Graph)
-@create_shared_memory.register(Sequence)
-def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
-    raise TypeError(
-        f"As {space} has a dynamic shape, it is not possible to create a static shared memory."
-    )
-
-
-@singledispatch
-def read_from_shared_memory(
-    space: Space, shared_memory: dict | tuple | mp.Array, n: int = 1
-) -> dict[str, Any] | tuple[Any, ...] | np.ndarray:
-    """Read the batch of observations from shared memory as a numpy array.
-
-    .. note::
-        The numpy array objects returned by `read_from_shared_memory` share the
-        memory of `shared_memory`. Any changes to `shared_memory` are forwarded
-        to `observations`, and vice-versa. To avoid any side-effects, use `np.copy`.
-
-    Args:
-        space: Observation space of a single environment in the vectorized environment.
-        shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
-            This object is created with `create_shared_memory`.
-        n: Number of environments in the vectorized environment (i.e. the number of processes).
-
-    Returns:
-        Batch of observations as a (possibly nested) numpy array.
-
-    Raises:
-        CustomSpaceError: If the space is a custom :class:`gymnasium.Space` instance with no registered `read_from_shared_memory` implementation
-        TypeError: If the space is not a :class:`gymnasium.Space` instance
-    """
-    if isinstance(space, Space):
-        raise CustomSpaceError(
-            f"Space of type `{type(space)}` doesn't have a registered `read_from_shared_memory` function. Register `{type(space)}` for `read_from_shared_memory` to support it."
-        )
-    else:
-        raise TypeError(
-            f"The space provided to `read_from_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
-        )
-
-
-@read_from_shared_memory.register(Box)
-@read_from_shared_memory.register(Discrete)
-@read_from_shared_memory.register(MultiDiscrete)
-@read_from_shared_memory.register(MultiBinary)
-def _read_base_from_shared_memory(
-    space: Box | Discrete | MultiDiscrete | MultiBinary, shared_memory, n: int = 1
-):
-    return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
-        (n,) + space.shape
-    )
-
-
-@read_from_shared_memory.register(Tuple)
-def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
-    return tuple(
-        read_from_shared_memory(subspace, memory, n=n)
-        for (memory, subspace) in zip(shared_memory, space.spaces)
-    )
-
-
-@read_from_shared_memory.register(Dict)
-def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
-    return OrderedDict(
-        [
-            (key, read_from_shared_memory(subspace, shared_memory[key], n=n))
-            for (key, subspace) in space.spaces.items()
-        ]
-    )
-
-
-@read_from_shared_memory.register(Text)
-def _read_text_from_shared_memory(
-    space: Text, shared_memory, n: int = 1
-) -> tuple[str, ...]:
-    data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape(
-        (n, space.max_length)
-    )
-
-    return tuple(
-        "".join(
-            [
-                space.character_list[val]
-                for val in values
-                if val < len(space.character_set)
-            ]
-        )
-        for values in data
-    )
-
-
-@singledispatch
-def write_to_shared_memory(
-    space: Space,
-    index: int,
-    value: np.ndarray,
-    shared_memory: dict[str, Any] | tuple[Any, ...] | mp.Array,
-):
-    """Write the observation of a single environment into shared memory.
-
-    Args:
-        space: Observation space of a single environment in the vectorized environment.
-        index: Index of the environment (must be in `[0, num_envs)`).
-        value: Observation of the single environment to write to shared memory.
-        shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
-            This object is created with `create_shared_memory`.
-
-    Raises:
-        CustomSpaceError: If the space is a custom :class:`gymnasium.Space` instance with no registered `write_to_shared_memory` implementation
-        TypeError: If the space is not a :class:`gymnasium.Space` instance
-    """
-    if isinstance(space, Space):
-        raise CustomSpaceError(
-            f"Space of type `{type(space)}` doesn't have a registered `write_to_shared_memory` function. Register `{type(space)}` for `write_to_shared_memory` to support it."
-        )
-    else:
-        raise TypeError(
-            f"The space provided to `write_to_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
-        )
-
-
-@write_to_shared_memory.register(Box)
-@write_to_shared_memory.register(Discrete)
-@write_to_shared_memory.register(MultiDiscrete)
-@write_to_shared_memory.register(MultiBinary)
-def _write_base_to_shared_memory(
-    space: Box | Discrete | MultiDiscrete | MultiBinary,
-    index: int,
-    value,
-    shared_memory,
-):
-    size = int(np.prod(space.shape))
-    destination = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
-    np.copyto(
-        destination[index * size : (index + 1) * size],
-        np.asarray(value, dtype=space.dtype).flatten(),
-    )
-
-
-@write_to_shared_memory.register(Tuple)
-def _write_tuple_to_shared_memory(
-    space: Tuple, index: int, values: tuple[Any, ...], shared_memory
-):
-    for value, memory, subspace in zip(values, shared_memory, space.spaces):
-        write_to_shared_memory(subspace, index, value, memory)
-
-
-@write_to_shared_memory.register(Dict)
-def _write_dict_to_shared_memory(
-    space: Dict, index: int, values: dict[str, Any], shared_memory
-):
-    for key, subspace in space.spaces.items():
-        write_to_shared_memory(subspace, index, values[key], shared_memory[key])
-
-
-@write_to_shared_memory.register(Text)
-def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_memory):
-    size = space.max_length
-    destination = np.frombuffer(shared_memory.get_obj(), dtype=np.int32)
-    np.copyto(
-        destination[index * size : (index + 1) * size],
-        flatten(space, values),
-    )
diff --git a/gymnasium/experimental/vector/vector_env.py b/gymnasium/experimental/vector/vector_env.py
deleted file mode 100644
index d40a96df2..000000000
--- a/gymnasium/experimental/vector/vector_env.py
+++ /dev/null
@@ -1,486 +0,0 @@
-"""Base class for vectorized environments."""
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Any, Generic, TypeVar
-
-import numpy as np
-
-import gymnasium as gym
-from gymnasium.core import ActType, ObsType
-from gymnasium.utils import seeding
-
-
-if TYPE_CHECKING:
-    from gymnasium.envs.registration import EnvSpec
-
-ArrayType = TypeVar("ArrayType")
-
-
-__all__ = [
-    "VectorEnv",
-    "VectorWrapper",
-    "VectorObservationWrapper",
-    "VectorActionWrapper",
-    "VectorRewardWrapper",
-    "ArrayType",
-]
-
-
-class VectorEnv(Generic[ObsType, ActType, ArrayType]):
-    """Base class for vectorized environments to run multiple independent copies of the same environment in parallel.
-
-    Vector environments can provide a linear speed-up in the steps taken per second through sampling multiple
-    sub-environments at the same time. To prevent terminated environments from waiting until all sub-environments have
-    terminated or truncated, the vector environments autoreset sub-environments after they terminate or truncate.
-    As a result, the final step's observation and info are overwritten by the reset's observation and info.
-    Therefore, the observation and info for the final step of a sub-environment is stored in the info parameter,
-    using `"final_observation"` and `"final_info"` respectively. See :meth:`step` for more information.
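-
-    For example, the final step's data can be recovered from the info dictionary
-    (an illustrative sketch of the autoreset bookkeeping; assumes ``CartPole-v1``
-    is registered and steps until some sub-environment terminates)::
-
-        >>> import gymnasium as gym
-        >>> envs = gym.vector.make("CartPole-v1", num_envs=2)
-        >>> obs, infos = envs.reset(seed=42)
-        >>> while "final_observation" not in infos:  # doctest: +SKIP
-        ...     obs, _, _, _, infos = envs.step(envs.action_space.sample())
-        >>> final_obs = infos["final_observation"]  # doctest: +SKIP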
-
-    The vector environments batch `observations`, `rewards`, `terminations`, `truncations` and `info` for each
-    parallel environment. In addition, :meth:`step` expects to receive a batch of actions for each parallel environment.
-
-    Gymnasium contains two types of Vector environments: :class:`AsyncVectorEnv` and :class:`SyncVectorEnv`.
-
-    Vector environments have additional attributes that help users understand the implementation:
-
-    - :attr:`num_envs` - The number of sub-environments in the vector environment
-    - :attr:`observation_space` - The batched observation space of the vector environment
-    - :attr:`single_observation_space` - The observation space of a single sub-environment
-    - :attr:`action_space` - The batched action space of the vector environment
-    - :attr:`single_action_space` - The action space of a single sub-environment
-
-    Note:
-        Before OpenAI Gym v25, the info parameter of :meth:`reset` and :meth:`step` was a list of
-        dictionaries, one for each sub-environment. However, this was modified in OpenAI Gym v25+ and in Gymnasium to a
-        dictionary with a NumPy array for each key. To use the old info style, apply the :class:`VectorListInfo` wrapper.
-
-    Note:
-        To render the sub-environments, use :meth:`call` with the ``"render"`` argument. Remember to set the `render_modes`
-        for all the sub-environments during initialization.
-
-    Note:
-        All parallel environments should share the identical observation and action spaces.
-        In other words, a vector of multiple different environments is not supported.
-    """
-
-    spec: EnvSpec
-
-    observation_space: gym.Space
-    action_space: gym.Space
-    single_observation_space: gym.Space
-    single_action_space: gym.Space
-
-    num_envs: int
-
-    closed = False
-
-    _np_random: np.random.Generator | None = None
-
-    def reset(
-        self,
-        *,
-        seed: int | list[int] | None = None,
-        options: dict[str, Any] | None = None,
-    ) -> tuple[ObsType, dict[str, Any]]:  # type: ignore
-        """Reset all parallel environments and return a batch of initial observations and info.
-
-        Args:
-            seed: The environment reset seeds
-            options: The reset options to pass to each sub-environment
-
-        Returns:
-            A batch of observations and info from the vectorized environment.
-
-        Example:
-            >>> import gymnasium as gym
-            >>> envs = gym.vector.make("CartPole-v1", num_envs=3)
-            >>> envs.reset(seed=42)
-            (array([[ 0.0273956 , -0.00611216,  0.03585979,  0.0197368 ],
-                   [ 0.01522993, -0.04562247, -0.04799704,  0.03392126],
-                   [-0.03774345, -0.02418869, -0.00942293,  0.0469184 ]],
-                  dtype=float32), {})
-        """
-        if seed is not None:
-            self._np_random, seed = seeding.np_random(seed)
-
-    def step(
-        self, actions: ActType
-    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
-        """Take an action for each parallel environment.
-
-        Args:
-            actions: Batch of actions, an element of :attr:`action_space`
-
-        Returns:
-            Batch of (observations, rewards, terminations, truncations, infos)
-
-        Note:
-            As vector environments autoreset terminating and truncating sub-environments,
-            the returned observation and info are not the final step's observation or info; those are instead stored in
-            info as `"final_observation"` and `"final_info"`.
-
-        Example:
-            >>> import gymnasium as gym
-            >>> import numpy as np
-            >>> envs = gym.vector.make("CartPole-v1", num_envs=3)
-            >>> _ = envs.reset(seed=42)
-            >>> actions = np.array([1, 0, 1])
-            >>> observations, rewards, termination, truncation, infos = envs.step(actions)
-            >>> observations
-            array([[ 0.02727336,  0.18847767,  0.03625453, -0.26141977],
-                   [ 0.01431748, -0.24002443, -0.04731862,  0.3110827 ],
-                   [-0.03822722,  0.1710671 , -0.00848456, -0.2487226 ]],
-                  dtype=float32)
-            >>> rewards
-            array([1., 1., 1.])
-            >>> termination
-            array([False, False, False])
-            >>> truncation
-            array([False, False, False])
-            >>> infos
-            {}
-        """
-        pass
-
-    def close_extras(self, **kwargs):
-        """Clean up the extra resources e.g. beyond what's in this base class."""
-        pass
-
-    def close(self, **kwargs):
-        """Close all parallel environments and release resources.
-
-        It also closes all the existing image viewers, then calls :meth:`close_extras` and sets
-        :attr:`closed` to ``True``.
-
-        Warnings:
-            This function itself does not close the environments, it should be handled
-            in :meth:`close_extras`. This is generic for both synchronous and asynchronous
-            vectorized environments.
-
-        Note:
-            This will be automatically called when garbage collected or program exited.
-
-        Args:
-            **kwargs: Keyword arguments passed to :meth:`close_extras`
-        """
-        if self.closed:
-            return
-
-        self.close_extras(**kwargs)
-        self.closed = True
-
-    @property
-    def np_random(self) -> np.random.Generator:
-        """Returns the environment's internal :attr:`_np_random`, which, if not set, will be initialised with a random seed.
-
-        Returns:
-            An instance of `np.random.Generator`
-        """
-        if self._np_random is None:
-            self._np_random, seed = seeding.np_random()
-        return self._np_random
-
-    @np_random.setter
-    def np_random(self, value: np.random.Generator):
-        self._np_random = value
-
-    @property
-    def unwrapped(self):
-        """Return the base environment."""
-        return self
-
-    def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
-        """Add env info to the info dictionary of the vectorized environment.
-
-        Given the `info` of a single environment, add it to the `infos` dictionary
-        which represents all the infos of the vectorized environment.
-        Every `key` of `info` is paired with a boolean mask `_key` representing
-        whether or not the i-indexed environment has this `info`.
-
-        Args:
-            infos (dict): the infos of the vectorized environment
-            info (dict): the info coming from the single environment
-            env_num (int): the index of the single environment
-
-        Returns:
-            infos (dict): the (updated) infos of the vectorized environment
-
-        """
-        for k in info.keys():
-            if k not in infos:
-                info_array, array_mask = self._init_info_arrays(type(info[k]))
-            else:
-                info_array, array_mask = infos[k], infos[f"_{k}"]
-
-            info_array[env_num], array_mask[env_num] = info[k], True
-            infos[k], infos[f"_{k}"] = info_array, array_mask
-        return infos
-
-    def _init_info_arrays(self, dtype: type) -> tuple[np.ndarray, np.ndarray]:
-        """Initialize the info array.
-
-        If the dtype is numeric, the info array will have the same dtype;
-        otherwise it will be an array of `None`. Also, a boolean array
-        of the same length is returned. It will be used for
-        assessing which environment has info data.
-
-        Args:
-            dtype (type): data type of the info coming from the env.
-
-        Returns:
-            array (np.ndarray): the initialized info array.
-            array_mask (np.ndarray): the initialized boolean array.
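-
-        Example:
-            An illustrative sketch for a hypothetical vector environment
-            ``envs`` with three sub-environments; for a numeric dtype, both
-            arrays are zero-initialised:
-
-            >>> array, mask = envs._init_info_arrays(int)  # doctest: +SKIP
-            >>> (array, mask)  # doctest: +SKIP
-            (array([0, 0, 0]), array([False, False, False]))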
- - """ - if dtype in [int, float, bool] or issubclass(dtype, np.number): - array = np.zeros(self.num_envs, dtype=dtype) - else: - array = np.zeros(self.num_envs, dtype=object) - array[:] = None - array_mask = np.zeros(self.num_envs, dtype=bool) - return array, array_mask - - def __del__(self): - """Closes the vector environment.""" - if not getattr(self, "closed", True): - self.close() - - def __repr__(self) -> str: - """Returns a string representation of the vector environment. - - Returns: - A string containing the class name, number of environments and environment spec id - """ - if getattr(self, "spec", None) is None: - return f"{self.__class__.__name__}({self.num_envs})" - else: - return f"{self.__class__.__name__}({self.spec.id}, {self.num_envs})" - - -class VectorWrapper(VectorEnv): - """Wraps the vectorized environment to allow a modular transformation. - - This class is the base class for all wrappers for vectorized environments. The subclass - could override some methods to change the behavior of the original vectorized environment - without touching the original code. - - Note: - Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`. - """ - - _observation_space: gym.Space | None = None - _action_space: gym.Space | None = None - _single_observation_space: gym.Space | None = None - _single_action_space: gym.Space | None = None - - def __init__(self, env: VectorEnv): - """Initialize the vectorized environment wrapper.""" - super().__init__() - - assert isinstance(env, VectorEnv) - self.env = env - - # explicitly forward the methods defined in VectorEnv - # to self.env (instead of the base class) - - def reset( - self, - *, - seed: int | list[int] | None = None, - options: dict[str, Any] | None = None, - ) -> tuple[ObsType, dict[str, Any]]: - """Reset all environment using seed and options.""" - return self.env.reset(seed=seed, options=options) - - def step( - self, actions: ActType - ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]: - """Step all environments.""" - return self.env.step(actions) - - def close(self, **kwargs: Any): - """Close all environments.""" - return self.env.close(**kwargs) - - def close_extras(self, **kwargs: Any): - """Close all extra resources.""" - return self.env.close_extras(**kwargs) - - # implicitly forward all other methods and attributes to self.env - def __getattr__(self, name: str) -> Any: - """Forward all other attributes to the base environment.""" - if name.startswith("_"): - raise AttributeError(f"attempted to get missing private attribute '{name}'") - return getattr(self.env, name) - - @property - def unwrapped(self): - """Return the base non-wrapped environment.""" - return self.env.unwrapped - - def __repr__(self): - """Return the string representation of the vectorized environment.""" - return f"<{self.__class__.__name__}, {self.env}>" - - def __del__(self): - """Close the vectorized environment.""" - self.env.__del__() - - @property - def spec(self) -> EnvSpec | None: - """Gets the specification of the wrapped environment.""" - return self.env.spec - - @property - def observation_space(self) -> gym.Space: - """Gets the observation space of the vector environment.""" - if self._observation_space is None: - return self.env.observation_space - return self._observation_space - - @observation_space.setter - def observation_space(self, space: gym.Space): - """Sets the observation space of the vector environment.""" - self._observation_space = space - - @property - def action_space(self) -> 
gym.Space:
-        """Gets the action space of the vector environment."""
-        if self._action_space is None:
-            return self.env.action_space
-        return self._action_space
-
-    @action_space.setter
-    def action_space(self, space: gym.Space):
-        """Sets the action space of the vector environment."""
-        self._action_space = space
-
-    @property
-    def single_observation_space(self) -> gym.Space:
-        """Gets the single observation space of the vector environment."""
-        if self._single_observation_space is None:
-            return self.env.single_observation_space
-        return self._single_observation_space
-
-    @single_observation_space.setter
-    def single_observation_space(self, space: gym.Space):
-        """Sets the single observation space of the vector environment."""
-        self._single_observation_space = space
-
-    @property
-    def single_action_space(self) -> gym.Space:
-        """Gets the single action space of the vector environment."""
-        if self._single_action_space is None:
-            return self.env.single_action_space
-        return self._single_action_space
-
-    @single_action_space.setter
-    def single_action_space(self, space: gym.Space):
-        """Sets the single action space of the vector environment."""
-        self._single_action_space = space
-
-    @property
-    def num_envs(self) -> int:
-        """Gets the number of sub-environments of the wrapped vector environment."""
-        return self.env.num_envs
-
-
-class VectorObservationWrapper(VectorWrapper):
-    """Wraps the vectorized environment to allow a modular transformation of the observation. Equivalent to :class:`gym.ObservationWrapper` for vectorized environments."""
-
-    def reset(
-        self,
-        *,
-        seed: int | list[int] | None = None,
-        options: dict[str, Any] | None = None,
-    ) -> tuple[ObsType, dict[str, Any]]:
-        """Modifies the observation returned from the environment ``reset`` using :meth:`vector_observation`."""
-        obs, info = self.env.reset(seed=seed, options=options)
-        return self.vector_observation(obs), info
-
-    def step(
-        self, actions: ActType
-    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
-        """Modifies the observation returned from the environment ``step`` using :meth:`vector_observation` (and :meth:`single_observation` for the final observations)."""
-        observation, reward, termination, truncation, info = self.env.step(actions)
-        return (
-            self.vector_observation(observation),
-            reward,
-            termination,
-            truncation,
-            self.update_final_obs(info),
-        )
-
-    def vector_observation(self, observation: ObsType) -> ObsType:
-        """Defines the vector observation transformation.
-
-        Args:
-            observation: A vector observation from the environment
-
-        Returns:
-            the transformed observation
-        """
-        raise NotImplementedError
-
-    def single_observation(self, observation: ObsType) -> ObsType:
-        """Defines the single observation transformation.
-
-        Args:
-            observation: A single observation from the environment
-
-        Returns:
-            The transformed observation
-        """
-        raise NotImplementedError
-
-    def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
-        """Updates the `final_obs` in the info using `single_observation`."""
-        if "final_observation" in info:
-            for i, obs in enumerate(info["final_observation"]):
-                if obs is not None:
-                    info["final_observation"][i] = self.single_observation(obs)
-        return info
-
-
-class VectorActionWrapper(VectorWrapper):
-    """Wraps the vectorized environment to allow a modular transformation of the actions.
Equivalent of :class:`~gym.ActionWrapper` for vectorized environments.""" - - def step( - self, actions: ActType - ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]: - """Steps through the environment using a modified action by :meth:`action`.""" - return self.env.step(self.actions(actions)) - - def actions(self, actions: ActType) -> ActType: - """Transform the actions before sending them to the environment. - - Args: - actions (ActType): the actions to transform - - Returns: - ActType: the transformed actions - """ - raise NotImplementedError - - -class VectorRewardWrapper(VectorWrapper): - """Wraps the vectorized environment to allow a modular transformation of the reward. Equivalent of :class:`~gym.RewardWrapper` for vectorized environments.""" - - def step( - self, actions: ActType - ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]: - """Steps through the environment returning a reward modified by :meth:`reward`.""" - observation, reward, termination, truncation, info = self.env.step(actions) - return observation, self.reward(reward), termination, truncation, info - - def reward(self, reward: ArrayType) -> ArrayType: - """Transform the reward before returning it. - - Args: - reward (array): the reward to transform - - Returns: - array: the transformed reward - """ - raise NotImplementedError diff --git a/gymnasium/experimental/wrappers/__init__.py b/gymnasium/experimental/wrappers/__init__.py deleted file mode 100644 index 5e57997f2..000000000 --- a/gymnasium/experimental/wrappers/__init__.py +++ /dev/null @@ -1,164 +0,0 @@ -"""`__init__` for experimental wrappers, to avoid loading the wrappers if unnecessary, we can hack python.""" -# pyright: reportUnsupportedDunderAll=false -import importlib -import re - -from gymnasium.error import DeprecatedWrapper -from gymnasium.experimental.wrappers import vector -from gymnasium.experimental.wrappers.atari_preprocessing import AtariPreprocessingV0 -from gymnasium.experimental.wrappers.common import ( - AutoresetV0, - OrderEnforcingV0, - PassiveEnvCheckerV0, - RecordEpisodeStatisticsV0, -) -from gymnasium.experimental.wrappers.lambda_action import ( - ClipActionV0, - LambdaActionV0, - RescaleActionV0, -) -from gymnasium.experimental.wrappers.lambda_observation import ( - DtypeObservationV0, - FilterObservationV0, - FlattenObservationV0, - GrayscaleObservationV0, - LambdaObservationV0, - PixelObservationV0, - RescaleObservationV0, - ReshapeObservationV0, - ResizeObservationV0, -) -from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0 -from gymnasium.experimental.wrappers.rendering import ( - HumanRenderingV0, - RecordVideoV0, - RenderCollectionV0, -) -from gymnasium.experimental.wrappers.stateful_action import StickyActionV0 -from gymnasium.experimental.wrappers.stateful_observation import ( - DelayObservationV0, - FrameStackObservationV0, - MaxAndSkipObservationV0, - NormalizeObservationV0, - TimeAwareObservationV0, -) -from gymnasium.experimental.wrappers.stateful_reward import NormalizeRewardV1 - - -# Todo - Add legacy wrapper to new wrapper error for users when merged into gymnasium.wrappers - - -__all__ = [ - "vector", - # --- Observation wrappers --- - "AtariPreprocessingV0", - "DelayObservationV0", - "DtypeObservationV0", - "FilterObservationV0", - "FlattenObservationV0", - "FrameStackObservationV0", - "GrayscaleObservationV0", - "LambdaObservationV0", - "MaxAndSkipObservationV0", - "NormalizeObservationV0", - "PixelObservationV0", - "ResizeObservationV0", - "ReshapeObservationV0", - 
"RescaleObservationV0", - "TimeAwareObservationV0", - # --- Action Wrappers --- - "ClipActionV0", - "LambdaActionV0", - "RescaleActionV0", - # "NanAction", - "StickyActionV0", - # --- Reward wrappers --- - "ClipRewardV0", - "LambdaRewardV0", - "NormalizeRewardV1", - # --- Common --- - "AutoresetV0", - "PassiveEnvCheckerV0", - "OrderEnforcingV0", - "RecordEpisodeStatisticsV0", - # --- Rendering --- - "RenderCollectionV0", - "RecordVideoV0", - "HumanRenderingV0", - # --- Conversion --- - "JaxToNumpyV0", - "JaxToTorchV0", - "NumpyToTorchV0", -] - -# As these wrappers requires `jax` or `torch`, they are loaded by runtime for users trying to access them -# to avoid `import jax` or `import torch` on `import gymnasium`. -_wrapper_to_class = { - # data converters - "JaxToNumpyV0": "jax_to_numpy", - "JaxToTorchV0": "jax_to_torch", - "NumpyToTorchV0": "numpy_to_torch", -} - - -def __getattr__(wrapper_name: str): - """Load a wrapper by name. - - This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used. - Errors will be raised if the wrapper does not exist or if the version is not the latest. - - Args: - wrapper_name: The name of a wrapper to load. - - Returns: - The specified wrapper. - - Raises: - AttributeError: If the wrapper does not exist. - DeprecatedWrapper: If the version is not the latest. - """ - # Check if the requested wrapper is in the _wrapper_to_class dictionary - if wrapper_name in _wrapper_to_class: - import_stmt = ( - f"gymnasium.experimental.wrappers.{_wrapper_to_class[wrapper_name]}" - ) - module = importlib.import_module(import_stmt) - return getattr(module, wrapper_name) - - # Define a regex pattern to match the integer suffix (version number) of the wrapper - int_suffix_pattern = r"(\d+)$" - version_match = re.search(int_suffix_pattern, wrapper_name) - - # If a version number is found, extract it and the base wrapper name - if version_match: - version = int(version_match.group()) - base_name = wrapper_name[: -len(version_match.group())] - else: - version = float("inf") - base_name = wrapper_name[:-2] - - # Filter the list of all wrappers to include only those with the same base name - matching_wrappers = [name for name in __all__ if name.startswith(base_name)] - - # If no matching wrappers are found, raise an AttributeError - if not matching_wrappers: - raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}") - - # Find the latest version of the matching wrappers - latest_wrapper = max( - matching_wrappers, key=lambda s: int(re.findall(int_suffix_pattern, s)[0]) - ) - latest_version = int(re.findall(int_suffix_pattern, latest_wrapper)[0]) - - # If the requested wrapper is an older version, raise a DeprecatedWrapper exception - if version < latest_version: - raise DeprecatedWrapper( - f"{wrapper_name!r} is now deprecated, use {latest_wrapper!r} instead.\n" - f"To see the changes made, go to " - f"https://gymnasium.farama.org/api/experimental/wrappers/#gymnasium.experimental.wrappers.{latest_wrapper}" - ) - # If the requested version is invalid, raise an AttributeError - else: - raise AttributeError( - f"module {__name__!r} has no attribute {wrapper_name!r}, did you mean {latest_wrapper!r}" - ) diff --git a/gymnasium/experimental/wrappers/atari_preprocessing.py b/gymnasium/experimental/wrappers/atari_preprocessing.py deleted file mode 100644 index aedf17fb3..000000000 --- a/gymnasium/experimental/wrappers/atari_preprocessing.py +++ /dev/null @@ -1,206 +0,0 @@ -"""Implementation of Atari 2600 Preprocessing following the 
guidelines of Machado et al., 2018."""
-import numpy as np
-
-import gymnasium as gym
-from gymnasium.spaces import Box
-
-
-__all__ = ["AtariPreprocessingV0"]
-
-
-class AtariPreprocessingV0(gym.Wrapper, gym.utils.RecordConstructorArgs):
-    """Atari 2600 preprocessing wrapper.
-
-    This class follows the guidelines in Machado et al. (2018),
-    "Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents".
-
-    Specifically, the following preprocessing stages apply to the Atari environment:
-
-    - Noop Reset: Obtains the initial state by taking a random number of no-ops on reset, default max 30 no-ops.
-    - Frame skipping: The number of frames skipped between steps, 4 by default
-    - Max-pooling: Pools over the most recent two observations from the frame skips
-    - Termination signal when a life is lost: When the agent loses a life, the environment terminates.
-        Turned off by default. Not recommended by Machado et al. (2018).
-    - Resize to a square image: Resizes the original Atari observation shape from 210x160 to 84x84 by default
-    - Grayscale observation: Whether to return a grayscale or colour observation; grayscale by default.
-    - Scale observation: Whether to scale the observation to the range [0, 1); by default, values remain in [0, 255].
-    """
-
-    def __init__(
-        self,
-        env: gym.Env,
-        noop_max: int = 30,
-        frame_skip: int = 4,
-        screen_size: int = 84,
-        terminal_on_life_loss: bool = False,
-        grayscale_obs: bool = True,
-        grayscale_newaxis: bool = False,
-        scale_obs: bool = False,
-    ):
-        """Wrapper for Atari 2600 preprocessing.
-
-        Args:
-            env (Env): The environment to apply the preprocessing
-            noop_max (int): The maximum number of no-op actions taken on reset; set to 0 to turn off.
-            frame_skip (int): The number of frames to skip between returned observations, affecting the frequency at which the agent experiences the game.
-            screen_size (int): The size to which Atari frames are resized (square).
-            terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a
-                life is lost.
-            grayscale_obs (bool): if True, then a grayscale observation is returned, otherwise an RGB observation
-                is returned.
-            grayscale_newaxis (bool): `if True and grayscale_obs=True`, then a channel axis is added to
-                grayscale observations to make them 3-dimensional.
-            scale_obs (bool): if True, then observations normalized to the range [0, 1) are returned. This also limits the memory
-                optimization benefits of the FrameStack wrapper.
-
-        Raises:
-            DependencyNotInstalled: opencv-python package not installed
-            ValueError: If frame-skipping is enabled in the original env while ``frame_skip > 1``
-        """
-        gym.utils.RecordConstructorArgs.__init__(
-            self,
-            noop_max=noop_max,
-            frame_skip=frame_skip,
-            screen_size=screen_size,
-            terminal_on_life_loss=terminal_on_life_loss,
-            grayscale_obs=grayscale_obs,
-            grayscale_newaxis=grayscale_newaxis,
-            scale_obs=scale_obs,
-        )
-        gym.Wrapper.__init__(self, env)
-
-        try:
-            import cv2  # noqa: F401
-        except ImportError:
-            raise gym.error.DependencyNotInstalled(
-                "opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari"
-            )
-
-        assert frame_skip > 0
-        assert screen_size > 0
-        assert noop_max >= 0
-        if frame_skip > 1:
-            if (
-                env.spec is not None
-                and "NoFrameskip" not in env.spec.id
-                and getattr(env.unwrapped, "_frameskip", None) != 1
-            ):
-                raise ValueError(
-                    "Disable frame-skipping in the original env. 
Otherwise, more than one "
-                    "frame-skip will happen, as this wrapper also applies frame-skipping."
-                )
-        self.noop_max = noop_max
-        assert env.unwrapped.get_action_meanings()[0] == "NOOP"
-
-        self.frame_skip = frame_skip
-        self.screen_size = screen_size
-        self.terminal_on_life_loss = terminal_on_life_loss
-        self.grayscale_obs = grayscale_obs
-        self.grayscale_newaxis = grayscale_newaxis
-        self.scale_obs = scale_obs
-
-        # buffer of most recent two observations for max pooling
-        assert isinstance(env.observation_space, Box)
-        if grayscale_obs:
-            self.obs_buffer = [
-                np.empty(env.observation_space.shape[:2], dtype=np.uint8),
-                np.empty(env.observation_space.shape[:2], dtype=np.uint8),
-            ]
-        else:
-            self.obs_buffer = [
-                np.empty(env.observation_space.shape, dtype=np.uint8),
-                np.empty(env.observation_space.shape, dtype=np.uint8),
-            ]
-
-        self.lives = 0
-        self.game_over = False
-
-        _low, _high, _obs_dtype = (
-            (0, 255, np.uint8) if not scale_obs else (0, 1, np.float32)
-        )
-        _shape = (screen_size, screen_size, 1 if grayscale_obs else 3)
-        if grayscale_obs and not grayscale_newaxis:
-            _shape = _shape[:-1]  # Remove channel axis
-        self.observation_space = Box(
-            low=_low, high=_high, shape=_shape, dtype=_obs_dtype
-        )
-
-    @property
-    def ale(self):
-        """Expose `ale` as a class property to avoid serialization errors."""
-        return self.env.unwrapped.ale
-
-    def step(self, action):
-        """Applies the preprocessing for an :meth:`env.step`."""
-        total_reward, terminated, truncated, info = 0.0, False, False, {}
-
-        for t in range(self.frame_skip):
-            _, reward, terminated, truncated, info = self.env.step(action)
-            total_reward += reward
-            self.game_over = terminated
-
-            if self.terminal_on_life_loss:
-                new_lives = self.ale.lives()
-                terminated = terminated or new_lives < self.lives
-                self.game_over = terminated
-                self.lives = new_lives
-
-            if terminated or truncated:
-                break
-            if t == self.frame_skip - 2:
-                if self.grayscale_obs:
-                    self.ale.getScreenGrayscale(self.obs_buffer[1])
-                else:
-                    self.ale.getScreenRGB(self.obs_buffer[1])
-            elif t == self.frame_skip - 1:
-                if self.grayscale_obs:
-                    self.ale.getScreenGrayscale(self.obs_buffer[0])
-                else:
-                    self.ale.getScreenRGB(self.obs_buffer[0])
-        return self._get_obs(), total_reward, terminated, truncated, info
-
-    def reset(self, **kwargs):
-        """Resets the environment using preprocessing."""
-        # NoopReset
-        _, reset_info = self.env.reset(**kwargs)
-
-        noops = (
-            self.env.unwrapped.np_random.integers(1, self.noop_max + 1)
-            if self.noop_max > 0
-            else 0
-        )
-        for _ in range(noops):
-            _, _, terminated, truncated, step_info = self.env.step(0)
-            reset_info.update(step_info)
-            if terminated or truncated:
-                _, reset_info = self.env.reset(**kwargs)
-
-        self.lives = self.ale.lives()
-        if self.grayscale_obs:
-            self.ale.getScreenGrayscale(self.obs_buffer[0])
-        else:
-            self.ale.getScreenRGB(self.obs_buffer[0])
-        self.obs_buffer[1].fill(0)
-
-        return self._get_obs(), reset_info
-
-    def _get_obs(self):
-        if self.frame_skip > 1:  # more efficient in-place pooling
-            np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0])
-
-        import cv2
-
-        obs = cv2.resize(
-            self.obs_buffer[0],
-            (self.screen_size, self.screen_size),
-            interpolation=cv2.INTER_AREA,
-        )
-
-        if self.scale_obs:
-            obs = np.asarray(obs, dtype=np.float32) / 255.0
-        else:
-            obs = np.asarray(obs, dtype=np.uint8)
-
-        if self.grayscale_obs and self.grayscale_newaxis:
-            obs = np.expand_dims(obs, axis=-1)  # Add a channel axis
-        return obs
diff --git a/gymnasium/experimental/wrappers/common.py 
b/gymnasium/experimental/wrappers/common.py
deleted file mode 100644
index c5e85dc8a..000000000
--- a/gymnasium/experimental/wrappers/common.py
+++ /dev/null
@@ -1,315 +0,0 @@
-"""A collection of common wrappers.
-
-* ``AutoresetV0`` - Auto-resets the environment
-* ``PassiveEnvCheckerV0`` - Passive environment checker that does not modify any environment data
-* ``OrderEnforcingV0`` - Enforces the order of function calls to environments
-* ``RecordEpisodeStatisticsV0`` - Records the episode statistics
-"""
-from __future__ import annotations
-
-import time
-from collections import deque
-from typing import Any, SupportsFloat
-
-import numpy as np
-
-import gymnasium as gym
-from gymnasium.core import ActType, ObsType, RenderFrame
-from gymnasium.error import ResetNeeded
-from gymnasium.utils.passive_env_checker import (
-    check_action_space,
-    check_observation_space,
-    env_render_passive_checker,
-    env_reset_passive_checker,
-    env_step_passive_checker,
-)
-
-
-__all__ = [
-    "AutoresetV0",
-    "PassiveEnvCheckerV0",
-    "OrderEnforcingV0",
-    "RecordEpisodeStatisticsV0",
-]
-
-
-class AutoresetV0(
-    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
-):
-    """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`."""
-
-    def __init__(self, env: gym.Env[ObsType, ActType]):
-        """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`.
-
-        Args:
-            env (gym.Env): The environment to apply the wrapper
-        """
-        gym.utils.RecordConstructorArgs.__init__(self)
-        gym.Wrapper.__init__(self, env)
-
-        self._episode_ended: bool = False
-        self._reset_options: dict[str, Any] | None = None
-
-    def step(
-        self, action: ActType
-    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
-        """Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered in the previous step.
-
-        Args:
-            action: The action to take
-
-        Returns:
-            The autoreset environment :meth:`step`
-        """
-        if self._episode_ended:
-            obs, info = self.env.reset(options=self._reset_options)
-            # a new episode has just started, so clear the ended flag
-            self._episode_ended = False
-            return obs, 0, False, False, info
-        else:
-            obs, reward, terminated, truncated, info = super().step(action)
-            self._episode_ended = terminated or truncated
-            return obs, reward, terminated, truncated, info
-
-    def reset(
-        self, *, seed: int | None = None, options: dict[str, Any] | None = None
-    ) -> tuple[ObsType, dict[str, Any]]:
-        """Resets the environment, saving the options used."""
-        self._episode_ended = False
-        self._reset_options = options
-        return super().reset(seed=seed, options=self._reset_options)
-
-
-class PassiveEnvCheckerV0(
-    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
-):
-    """A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
-
-    def __init__(self, env: gym.Env[ObsType, ActType]):
-        """Initialises the wrapper with the environment, running the observation and action space tests."""
-        gym.utils.RecordConstructorArgs.__init__(self)
-        gym.Wrapper.__init__(self, env)
-
-        assert hasattr(
-            env, "action_space"
-        ), "The environment must specify an action space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
-        check_action_space(env.action_space)
-        assert hasattr(
-            env, "observation_space"
-        ), "The environment must specify an observation space. 
-
-
-class PassiveEnvCheckerV0(
-    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
-):
-    """A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API."""
-
-    def __init__(self, env: gym.Env[ObsType, ActType]):
-        """Initialises the wrapper with the environment, running the observation and action space tests."""
-        gym.utils.RecordConstructorArgs.__init__(self)
-        gym.Wrapper.__init__(self, env)
-
-        assert hasattr(
-            env, "action_space"
-        ), "The environment must specify an action space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
-        check_action_space(env.action_space)
-        assert hasattr(
-            env, "observation_space"
-        ), "The environment must specify an observation space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/"
-        check_observation_space(env.observation_space)
-
-        self._checked_reset: bool = False
-        self._checked_step: bool = False
-        self._checked_render: bool = False
-
-    def step(
-        self, action: ActType
-    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
-        """Steps through the environment, running `env_step_passive_checker` on the initial call."""
-        if self._checked_step is False:
-            self._checked_step = True
-            return env_step_passive_checker(self.env, action)
-        else:
-            return self.env.step(action)
-
-    def reset(
-        self, *, seed: int | None = None, options: dict[str, Any] | None = None
-    ) -> tuple[ObsType, dict[str, Any]]:
-        """Resets the environment, running `env_reset_passive_checker` on the initial call."""
-        if self._checked_reset is False:
-            self._checked_reset = True
-            return env_reset_passive_checker(self.env, seed=seed, options=options)
-        else:
-            return self.env.reset(seed=seed, options=options)
-
-    def render(self) -> RenderFrame | list[RenderFrame] | None:
-        """Renders the environment, running `env_render_passive_checker` on the initial call."""
-        if self._checked_render is False:
-            self._checked_render = True
-            return env_render_passive_checker(self.env)
-        else:
-            return self.env.render()
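The checker above validates each API function only on its first call and afterwards delegates directly. A minimal sketch of that check-once pattern (the names here are illustrative, not from the diff):

```python
class CheckOnce:
    """Call `checker` on the first invocation, then fall back to `fn`."""

    def __init__(self, fn, checker):
        self.fn, self.checker = fn, checker
        self.checked = False

    def __call__(self, *args, **kwargs):
        if not self.checked:
            self.checked = True
            return self.checker(*args, **kwargs)  # validating slow path, once
        return self.fn(*args, **kwargs)  # plain delegation afterwards
```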
-
-
-class OrderEnforcingV0(
-    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
-):
-    """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
-
-    Example:
-        >>> import gymnasium as gym
-        >>> from gymnasium.experimental.wrappers import OrderEnforcingV0
-        >>> env = gym.make("CartPole-v1", render_mode="human")
-        >>> env = OrderEnforcingV0(env)
-        >>> env.step(0)
-        Traceback (most recent call last):
-            ...
-        gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset()
-        >>> env.render()
-        Traceback (most recent call last):
-            ...
-        gymnasium.error.ResetNeeded: Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper.
-        >>> _ = env.reset()
-        >>> env.render()
-        >>> _ = env.step(0)
-        >>> env.close()
-    """
-
-    def __init__(
-        self,
-        env: gym.Env[ObsType, ActType],
-        disable_render_order_enforcing: bool = False,
-    ):
-        """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.
-
-        Args:
-            env: The environment to wrap
-            disable_render_order_enforcing: Whether to disable render order enforcing
-        """
-        gym.utils.RecordConstructorArgs.__init__(
-            self, disable_render_order_enforcing=disable_render_order_enforcing
-        )
-        gym.Wrapper.__init__(self, env)
-
-        self._has_reset: bool = False
-        self._disable_render_order_enforcing: bool = disable_render_order_enforcing
-
-    def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]:
-        """Steps through the environment."""
-        if not self._has_reset:
-            raise ResetNeeded("Cannot call env.step() before calling env.reset()")
-        return super().step(action)
-
-    def reset(
-        self, *, seed: int | None = None, options: dict[str, Any] | None = None
-    ) -> tuple[ObsType, dict[str, Any]]:
-        """Resets the environment with the given `seed` and `options`."""
-        self._has_reset = True
-        return super().reset(seed=seed, options=options)
-
-    def render(self) -> RenderFrame | list[RenderFrame] | None:
-        """Renders the environment, erroring if :meth:`reset` has not been called while order enforcing is enabled."""
-        if not self._disable_render_order_enforcing and not self._has_reset:
-            raise ResetNeeded(
-                "Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, "
-                "set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
-            )
-        return super().render()
-
-    @property
-    def has_reset(self):
-        """Returns if the environment has been reset before."""
-        return self._has_reset
-
-
-class RecordEpisodeStatisticsV0(
-    gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs
-):
-    """This wrapper will keep track of cumulative rewards and episode lengths.
-
-    At the end of an episode, the statistics of the episode will be added to ``info``
-    using the key ``episode``. If using a vectorized environment, the key
-    ``_episode`` is also used, indicating whether the env at the respective index has
-    the episode statistics.
-
-    After the completion of an episode, ``info`` will look like this::
-
-        >>> info = {
-        ...     "episode": {
-        ...         "r": "<cumulative reward>",
-        ...         "l": "<episode length>",
-        ...         "t": "<elapsed time since beginning of episode>"
-        ...     },
-        ... }
-
-    For vectorized environments the output will be in the form of::
-
-        >>> infos = {
-        ...     "final_observation": "<array of length num-envs>",
-        ...     "_final_observation": "<boolean array of length num-envs>",
-        ...     "final_info": "<array of length num-envs>",
-        ...     "_final_info": "<boolean array of length num-envs>",
-        ...     "episode": {
-        ...         "r": "<array of cumulative reward>",
-        ...         "l": "<array of episode length>",
-        ...         "t": "<array of elapsed time since beginning of episode>"
-        ...     },
-        ...     "_episode": "<boolean array of length num-envs>"
-        ... }
-
-    Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via
-    :attr:`wrapped_env.episode_reward_buffer` and :attr:`wrapped_env.episode_length_buffer` respectively.
-
-    Attributes:
-        episode_reward_buffer: The cumulative rewards of the last ``buffer_length``-many episodes
-        episode_length_buffer: The lengths of the last ``buffer_length``-many episodes
-    """
-
-    def __init__(
-        self,
-        env: gym.Env[ObsType, ActType],
-        buffer_length: int | None = 100,
-        stats_key: str = "episode",
-    ):
-        """This wrapper will keep track of cumulative rewards and episode lengths.
-
-        Args:
-            env (Env): The environment to apply the wrapper
-            buffer_length: The size of the buffers :attr:`episode_reward_buffer` and :attr:`episode_length_buffer`
-            stats_key: The info key for the episode statistics
-        """
-        gym.utils.RecordConstructorArgs.__init__(self)
-        gym.Wrapper.__init__(self, env)
-
-        self._stats_key = stats_key
-
-        self.episode_count = 0
-        self.episode_start_time: float = -1
-        self.episode_reward: float = -1
-        self.episode_length: int = -1
-
-        self.episode_time_length_buffer: deque[int] = deque(maxlen=buffer_length)
-        self.episode_reward_buffer: deque[float] = deque(maxlen=buffer_length)
-        self.episode_length_buffer: deque[int] = deque(maxlen=buffer_length)
-
-    def step(
-        self, action: ActType
-    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
-        """Steps through the environment, recording the episode statistics."""
-        obs, reward, terminated, truncated, info = super().step(action)
-
-        self.episode_reward += reward
-        self.episode_length += 1
-
-        if terminated or truncated:
-            assert self._stats_key not in info
-
-            episode_time_length = np.round(
-                time.perf_counter() - self.episode_start_time, 6
-            )
-            info[self._stats_key] = {
-                "r": self.episode_reward,
-                "l": self.episode_length,
-                "t": episode_time_length,
-            }
-
-            self.episode_time_length_buffer.append(episode_time_length)
-            self.episode_reward_buffer.append(self.episode_reward)
-            self.episode_length_buffer.append(self.episode_length)
-
-            self.episode_count += 1
-
-        return obs, reward, terminated, truncated, info
-
-    def reset(
-        self, *, seed: int | None = None, options: dict[str, Any] | None = None
-    ) -> tuple[ObsType, dict[str, Any]]:
-        """Resets the environment using seed and options, and resets the episode rewards and lengths."""
-        obs, info = super().reset(seed=seed, options=options)
-
-        self.episode_start_time = time.perf_counter()
-        self.episode_reward = 0
-        self.episode_length = 0
-
-        return obs, info
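A hedged usage sketch for ``RecordEpisodeStatisticsV0`` as defined above, again assuming the pre-v1.0 import path that this patch removes:

```python
import numpy as np
import gymnasium as gym
from gymnasium.experimental.wrappers import RecordEpisodeStatisticsV0  # pre-v1.0 path

env = RecordEpisodeStatisticsV0(gym.make("CartPole-v1"), buffer_length=100)
obs, info = env.reset(seed=0)
for _ in range(500):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        # "episode" holds the r/l/t statistics of the episode that just ended.
        print(info["episode"]["r"], info["episode"]["l"], info["episode"]["t"])
        obs, info = env.reset()

# Rolling statistics over the last `buffer_length` episodes.
print(np.mean(env.episode_reward_buffer), np.mean(env.episode_length_buffer))
```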
diff --git a/gymnasium/experimental/wrappers/lambda_observation.py b/gymnasium/experimental/wrappers/lambda_observation.py
deleted file mode 100644
index 36f0d70fc..000000000
--- a/gymnasium/experimental/wrappers/lambda_observation.py
+++ /dev/null
@@ -1,620 +0,0 @@
-"""A collection of observation wrappers using a lambda function.
- -* ``LambdaObservationV0`` - Transforms the observation with a function -* ``FilterObservationV0`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys -* ``FlattenObservationV0`` - Flattens the observations -* ``GrayscaleObservationV0`` - Converts a RGB observation to a grayscale observation -* ``ResizeObservationV0`` - Resizes an array-based observation (normally a RGB observation) -* ``ReshapeObservationV0`` - Reshapes an array-based observation -* ``RescaleObservationV0`` - Rescales an observation to between a minimum and maximum value -* ``DtypeObservationV0`` - Convert an observation to a dtype -* ``PixelObservationV0`` - Allows the observation to the rendered frame -""" -from __future__ import annotations - -from typing import Any, Callable, Final, Sequence - -import numpy as np - -import gymnasium as gym -from gymnasium import spaces -from gymnasium.core import ActType, ObsType, WrapperObsType -from gymnasium.error import DependencyNotInstalled - - -__all__ = [ - "LambdaObservationV0", - "FilterObservationV0", - "FlattenObservationV0", - "GrayscaleObservationV0", - "ResizeObservationV0", - "ReshapeObservationV0", - "RescaleObservationV0", - "DtypeObservationV0", - "PixelObservationV0", -] - - -class LambdaObservationV0( - gym.ObservationWrapper[WrapperObsType, ActType, ObsType], - gym.utils.RecordConstructorArgs, -): - """Transforms an observation via a function provided to the wrapper. - - The function :attr:`func` will be applied to all observations. - If the observations from :attr:`func` are outside the bounds of the ``env``'s observation space, provide an :attr:`observation_space`. - - Example: - >>> import gymnasium as gym - >>> from gymnasium.experimental.wrappers import LambdaObservationV0 - >>> import numpy as np - >>> np.random.seed(0) - >>> env = gym.make("CartPole-v1") - >>> env = LambdaObservationV0(env, lambda obs: obs + 0.1 * np.random.random(obs.shape), env.observation_space) - >>> env.reset(seed=42) - (array([0.08227695, 0.06540678, 0.09613613, 0.07422512]), {}) - """ - - def __init__( - self, - env: gym.Env[ObsType, ActType], - func: Callable[[ObsType], Any], - observation_space: gym.Space[WrapperObsType] | None, - ): - """Constructor for the lambda observation wrapper. - - Args: - env: The environment to wrap - func: A function that will transform an observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an `observation_space`. - observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``. - """ - gym.utils.RecordConstructorArgs.__init__( - self, func=func, observation_space=observation_space - ) - gym.ObservationWrapper.__init__(self, env) - - if observation_space is not None: - self.observation_space = observation_space - - self.func = func - - def observation(self, observation: ObsType) -> Any: - """Apply function to the observation.""" - return self.func(observation) - - -class FilterObservationV0( - LambdaObservationV0[WrapperObsType, ActType, ObsType], - gym.utils.RecordConstructorArgs, -): - """Filters Dict or Tuple observation space by the keys or indexes. 
-
-    Example:
-        >>> import gymnasium as gym
-        >>> from gymnasium.wrappers import TransformObservation
-        >>> from gymnasium.experimental.wrappers import FilterObservationV0
-        >>> env = gym.make("CartPole-v1")
-        >>> env = gym.wrappers.TransformObservation(env, lambda obs: {'obs': obs, 'time': 0})
-        >>> env.observation_space = gym.spaces.Dict(obs=env.observation_space, time=gym.spaces.Discrete(1))
-        >>> env.reset(seed=42)
-        ({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': 0}, {})
-        >>> env = FilterObservationV0(env, filter_keys=['time'])
-        >>> env.reset(seed=42)
-        ({'time': 0}, {})
-        >>> env.step(0)
-        ({'time': 0}, 1.0, False, False, {})
-    """
-
-    def __init__(
-        self, env: gym.Env[ObsType, ActType], filter_keys: Sequence[str | int]
-    ):
-        """Constructor for the filter observation wrapper.
-
-        Args:
-            env: The environment to wrap
-            filter_keys: The subspaces to be included, use a list of strings or integers for ``Dict`` and ``Tuple`` spaces respectively
-        """
-        assert isinstance(filter_keys, Sequence)
-        gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys)
-
-        # Filters for dictionary space
-        if isinstance(env.observation_space, spaces.Dict):
-            assert all(isinstance(key, str) for key in filter_keys)
-
-            if any(
-                key not in env.observation_space.spaces.keys() for key in filter_keys
-            ):
-                missing_keys = [
-                    key
-                    for key in filter_keys
-                    if key not in env.observation_space.spaces.keys()
-                ]
-                raise ValueError(
-                    "All the `filter_keys` must be included in the observation space.\n"
-                    f"Filter keys: {filter_keys}\n"
-                    f"Observation keys: {list(env.observation_space.spaces.keys())}\n"
-                    f"Missing keys: {missing_keys}"
-                )
-
-            new_observation_space = spaces.Dict(
-                {key: env.observation_space[key] for key in filter_keys}
-            )
-            if len(new_observation_space) == 0:
-                raise ValueError(
-                    "The observation space is empty due to filtering all keys."
-                )
-
-            LambdaObservationV0.__init__(
-                self,
-                env=env,
-                func=lambda obs: {key: obs[key] for key in filter_keys},
-                observation_space=new_observation_space,
-            )
-        # Filter for tuple observation
-        elif isinstance(env.observation_space, spaces.Tuple):
-            assert all(isinstance(key, int) for key in filter_keys)
-            assert len(set(filter_keys)) == len(
-                filter_keys
-            ), f"Duplicate keys exist, filter_keys: {filter_keys}"
-
-            if any(
-                key < 0 or key >= len(env.observation_space) for key in filter_keys
-            ):
-                missing_index = [
-                    key
-                    for key in filter_keys
-                    if key < 0 or key >= len(env.observation_space)
-                ]
-                raise ValueError(
-                    "All the `filter_keys` must be included in the length of the observation space.\n"
-                    f"Filter keys: {filter_keys}, length of observation: {len(env.observation_space)}, "
-                    f"missing indexes: {missing_index}"
-                )
-
-            new_observation_spaces = spaces.Tuple(
-                env.observation_space[key] for key in filter_keys
-            )
-            if len(new_observation_spaces) == 0:
-                raise ValueError(
-                    "The observation space is empty due to filtering all keys."
- ) - - LambdaObservationV0.__init__( - self, - env=env, - func=lambda obs: tuple(obs[key] for key in filter_keys), - observation_space=new_observation_spaces, - ) - else: - raise ValueError( - f"FilterObservation wrapper is only usable with `Dict` and `Tuple` observations, actual type: {type(env.observation_space)}" - ) - - self.filter_keys: Final[Sequence[str | int]] = filter_keys - - -class FlattenObservationV0( - LambdaObservationV0[WrapperObsType, ActType, ObsType], - gym.utils.RecordConstructorArgs, -): - """Observation wrapper that flattens the observation. - - Example: - >>> import gymnasium as gym - >>> from gymnasium.experimental.wrappers import FlattenObservationV0 - >>> env = gym.make("CarRacing-v2") - >>> env.observation_space.shape - (96, 96, 3) - >>> env = FlattenObservationV0(env) - >>> env.observation_space.shape - (27648,) - >>> obs, _ = env.reset() - >>> obs.shape - (27648,) - """ - - def __init__(self, env: gym.Env[ObsType, ActType]): - """Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``. - - Args: - env: The environment to wrap - """ - gym.utils.RecordConstructorArgs.__init__(self) - LambdaObservationV0.__init__( - self, - env=env, - func=lambda obs: spaces.utils.flatten(env.observation_space, obs), - observation_space=spaces.utils.flatten_space(env.observation_space), - ) - - -class GrayscaleObservationV0( - LambdaObservationV0[WrapperObsType, ActType, ObsType], - gym.utils.RecordConstructorArgs, -): - """Observation wrapper that converts an RGB image to grayscale. - - The :attr:`keep_dim` will keep the channel dimension - - Example: - >>> import gymnasium as gym - >>> from gymnasium.experimental.wrappers import GrayscaleObservationV0 - >>> env = gym.make("CarRacing-v2") - >>> env.observation_space.shape - (96, 96, 3) - >>> grayscale_env = GrayscaleObservationV0(env) - >>> grayscale_env.observation_space.shape - (96, 96) - >>> grayscale_env = GrayscaleObservationV0(env, keep_dim=True) - >>> grayscale_env.observation_space.shape - (96, 96, 1) - """ - - def __init__(self, env: gym.Env[ObsType, ActType], keep_dim: bool = False): - """Constructor for an RGB image based environments to make the image grayscale. 
- - Args: - env: The environment to wrap - keep_dim: If to keep the channel in the observation, if ``True``, ``obs.shape == 3`` else ``obs.shape == 2`` - """ - assert isinstance(env.observation_space, spaces.Box) - assert ( - len(env.observation_space.shape) == 3 - and env.observation_space.shape[-1] == 3 - ) - assert ( - np.all(env.observation_space.low == 0) - and np.all(env.observation_space.high == 255) - and env.observation_space.dtype == np.uint8 - ) - gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim) - - self.keep_dim: Final[bool] = keep_dim - if keep_dim: - new_observation_space = spaces.Box( - low=0, - high=255, - shape=env.observation_space.shape[:2] + (1,), - dtype=np.uint8, - ) - LambdaObservationV0.__init__( - self, - env=env, - func=lambda obs: np.expand_dims( - np.sum( - np.multiply(obs, np.array([0.2125, 0.7154, 0.0721])), axis=-1 - ).astype(np.uint8), - axis=-1, - ), - observation_space=new_observation_space, - ) - else: - new_observation_space = spaces.Box( - low=0, high=255, shape=env.observation_space.shape[:2], dtype=np.uint8 - ) - LambdaObservationV0.__init__( - self, - env=env, - func=lambda obs: np.sum( - np.multiply(obs, np.array([0.2125, 0.7154, 0.0721])), axis=-1 - ).astype(np.uint8), - observation_space=new_observation_space, - ) - - -class ResizeObservationV0( - LambdaObservationV0[WrapperObsType, ActType, ObsType], - gym.utils.RecordConstructorArgs, -): - """Resizes image observations using OpenCV to shape. - - Example: - >>> import gymnasium as gym - >>> from gymnasium.experimental.wrappers import ResizeObservationV0 - >>> env = gym.make("CarRacing-v2") - >>> env.observation_space.shape - (96, 96, 3) - >>> resized_env = ResizeObservationV0(env, (32, 32)) - >>> resized_env.observation_space.shape - (32, 32, 3) - """ - - def __init__(self, env: gym.Env[ObsType, ActType], shape: tuple[int, ...]): - """Constructor that requires an image environment observation space with a shape. - - Args: - env: The environment to wrap - shape: The resized observation shape - """ - assert isinstance(env.observation_space, spaces.Box) - assert len(env.observation_space.shape) in [2, 3] - assert np.all(env.observation_space.low == 0) and np.all( - env.observation_space.high == 255 - ) - assert env.observation_space.dtype == np.uint8 - - assert isinstance(shape, tuple) - assert all(np.issubdtype(type(elem), np.integer) for elem in shape) - assert all(x > 0 for x in shape) - - try: - import cv2 - except ImportError as e: - raise DependencyNotInstalled( - "opencv (cv2) is not installed, run `pip install gymnasium[other]`" - ) from e - - self.shape: Final[tuple[int, ...]] = tuple(shape) - - new_observation_space = spaces.Box( - low=0, - high=255, - shape=self.shape + env.observation_space.shape[2:], - dtype=np.uint8, - ) - - gym.utils.RecordConstructorArgs.__init__(self, shape=shape) - LambdaObservationV0.__init__( - self, - env=env, - func=lambda obs: cv2.resize(obs, self.shape, interpolation=cv2.INTER_AREA), - observation_space=new_observation_space, - ) - - -class ReshapeObservationV0( - LambdaObservationV0[WrapperObsType, ActType, ObsType], - gym.utils.RecordConstructorArgs, -): - """Reshapes array based observations to shapes. 
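The grayscale conversion above weights the RGB channels with the ITU-R 601 luma coefficients ``(0.2125, 0.7154, 0.0721)``. A standalone numpy sketch of the same transform, with an illustrative 96x96 frame:

```python
import numpy as np

rgb = np.random.randint(0, 256, (96, 96, 3), dtype=np.uint8)  # illustrative frame
weights = np.array([0.2125, 0.7154, 0.0721])

gray = np.sum(rgb * weights, axis=-1).astype(np.uint8)  # shape (96, 96)
gray_keep_dim = np.expand_dims(gray, axis=-1)           # shape (96, 96, 1)
```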
- - Example: - >>> import gymnasium as gym - >>> from gymnasium.experimental.wrappers import ReshapeObservationV0 - >>> env = gym.make("CarRacing-v2") - >>> env.observation_space.shape - (96, 96, 3) - >>> reshape_env = ReshapeObservationV0(env, (24, 4, 96, 1, 3)) - >>> reshape_env.observation_space.shape - (24, 4, 96, 1, 3) - """ - - def __init__(self, env: gym.Env[ObsType, ActType], shape: int | tuple[int, ...]): - """Constructor for env with ``Box`` observation space that has a shape product equal to the new shape product. - - Args: - env: The environment to wrap - shape: The reshaped observation space - """ - assert isinstance(env.observation_space, spaces.Box) - assert np.product(shape) == np.product(env.observation_space.shape) - - assert isinstance(shape, tuple) - assert all(np.issubdtype(type(elem), np.integer) for elem in shape) - assert all(x > 0 or x == -1 for x in shape) - - new_observation_space = spaces.Box( - low=np.reshape(np.ravel(env.observation_space.low), shape), - high=np.reshape(np.ravel(env.observation_space.high), shape), - shape=shape, - dtype=env.observation_space.dtype, - ) - self.shape = shape - - gym.utils.RecordConstructorArgs.__init__(self, shape=shape) - LambdaObservationV0.__init__( - self, - env=env, - func=lambda obs: np.reshape(obs, shape), - observation_space=new_observation_space, - ) - - -class RescaleObservationV0( - LambdaObservationV0[WrapperObsType, ActType, ObsType], - gym.utils.RecordConstructorArgs, -): - """Linearly rescales observation to between a minimum and maximum value. - - Example: - >>> import gymnasium as gym - >>> from gymnasium.experimental.wrappers import RescaleObservationV0 - >>> env = gym.make("Pendulum-v1") - >>> env.observation_space - Box([-1. -1. -8.], [1. 1. 8.], (3,), float32) - >>> env = RescaleObservationV0(env, np.array([-2, -1, -10], dtype=np.float32), np.array([1, 0, 1], dtype=np.float32)) - >>> env.observation_space - Box([ -2. -1. -10.], [1. 0. 1.], (3,), float32) - """ - - def __init__( - self, - env: gym.Env[ObsType, ActType], - min_obs: np.floating | np.integer | np.ndarray, - max_obs: np.floating | np.integer | np.ndarray, - ): - """Constructor that requires the env observation spaces to be a :class:`Box`. 
- - Args: - env: The environment to wrap - min_obs: The new minimum observation bound - max_obs: The new maximum observation bound - """ - assert isinstance(env.observation_space, spaces.Box) - assert not np.any(env.observation_space.low == np.inf) and not np.any( - env.observation_space.high == np.inf - ) - - if not isinstance(min_obs, np.ndarray): - assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype( - type(max_obs), np.floating - ) - min_obs = np.full(env.observation_space.shape, min_obs) - assert ( - min_obs.shape == env.observation_space.shape - ), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}" - assert not np.any(min_obs == np.inf) - - if not isinstance(max_obs, np.ndarray): - assert np.issubdtype(type(max_obs), np.integer) or np.issubdtype( - type(max_obs), np.floating - ) - max_obs = np.full(env.observation_space.shape, max_obs) - assert max_obs.shape == env.observation_space.shape - assert not np.any(max_obs == np.inf) - - self.min_obs = min_obs - self.max_obs = max_obs - - # Imagine the x-axis between the old Box and the y-axis being the new Box - gradient = (max_obs - min_obs) / ( - env.observation_space.high - env.observation_space.low - ) - intercept = gradient * -env.observation_space.low + min_obs - - gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs) - LambdaObservationV0.__init__( - self, - env=env, - func=lambda obs: gradient * obs + intercept, - observation_space=spaces.Box( - low=min_obs, - high=max_obs, - shape=env.observation_space.shape, - dtype=env.observation_space.dtype, - ), - ) - - -class DtypeObservationV0( - LambdaObservationV0[WrapperObsType, ActType, ObsType], - gym.utils.RecordConstructorArgs, -): - """Observation wrapper for transforming the dtype of an observation. - - Note: - This is only compatible with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces - """ - - def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any): - """Constructor for Dtype observation wrapper. - - Args: - env: The environment to wrap - dtype: The new dtype of the observation - """ - assert isinstance( - env.observation_space, - (spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary), - ) - - self.dtype = dtype - if isinstance(env.observation_space, spaces.Box): - new_observation_space = spaces.Box( - low=env.observation_space.low, - high=env.observation_space.high, - shape=env.observation_space.shape, - dtype=self.dtype, - ) - elif isinstance(env.observation_space, spaces.Discrete): - new_observation_space = spaces.Box( - low=env.observation_space.start, - high=env.observation_space.start + env.observation_space.n, - shape=(), - dtype=self.dtype, - ) - elif isinstance(env.observation_space, spaces.MultiDiscrete): - new_observation_space = spaces.MultiDiscrete( - env.observation_space.nvec, dtype=dtype - ) - elif isinstance(env.observation_space, spaces.MultiBinary): - new_observation_space = spaces.Box( - low=0, - high=1, - shape=env.observation_space.shape, - dtype=self.dtype, - ) - else: - raise TypeError( - "DtypeObservation is only compatible with value / array-based observations." 
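The ``gradient``/``intercept`` computation above is a per-element linear map from the old ``Box`` bounds onto the new ones. A small numeric check of the formula, with illustrative bounds:

```python
import numpy as np

low, high = np.array([-1.0, -8.0]), np.array([1.0, 8.0])          # old Box bounds
new_min, new_max = np.array([-2.0, -10.0]), np.array([1.0, 1.0])  # requested bounds

gradient = (new_max - new_min) / (high - low)
intercept = gradient * -low + new_min

assert np.allclose(gradient * high + intercept, new_max)  # old upper -> new upper
assert np.allclose(gradient * low + intercept, new_min)   # old lower -> new lower
```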
-            )
-
-        gym.utils.RecordConstructorArgs.__init__(self, dtype=dtype)
-        LambdaObservationV0.__init__(
-            self,
-            env=env,
-            func=lambda obs: dtype(obs),
-            observation_space=new_observation_space,
-        )
-
-
-class PixelObservationV0(
-    LambdaObservationV0[WrapperObsType, ActType, ObsType],
-    gym.utils.RecordConstructorArgs,
-):
-    """Includes the rendered observations in the environment's observations.
-
-    Observations of this wrapper will be dictionaries of images.
-    You can also choose to add the observation of the base environment to this dictionary.
-    In that case, if the base environment has an observation space of type :class:`Dict`, the dictionary
-    of rendered images will be updated with the base environment's observation. If, however, the observation
-    space is of type :class:`Box`, the base environment's observation (which will be an element of the :class:`Box`
-    space) will be added to the dictionary under the key "state".
-    """
-
-    def __init__(
-        self,
-        env: gym.Env[ObsType, ActType],
-        pixels_only: bool = True,
-        pixels_key: str = "pixels",
-        obs_key: str = "state",
-    ):
-        """Constructor of the pixel observation wrapper.
-
-        Args:
-            env: The environment to wrap.
-            pixels_only (bool): If ``True`` (default), the original observation returned
-                by the wrapped environment will be discarded, and a dictionary
-                observation will only include pixels. If ``False``, the
-                observation dictionary will contain both the original
-                observations and the pixel observations.
-            pixels_key: Optional custom string specifying the pixel key. Defaults to "pixels"
-            obs_key: Optional custom string specifying the obs key. Defaults to "state"
-        """
-        gym.utils.RecordConstructorArgs.__init__(
-            self, pixels_only=pixels_only, pixels_key=pixels_key, obs_key=obs_key
-        )
-
-        assert env.render_mode is not None and env.render_mode != "human"
-        env.reset()
-        pixels = env.render()
-        assert pixels is not None and isinstance(pixels, np.ndarray)
-        pixel_space = spaces.Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8)
-
-        if pixels_only:
-            obs_space = pixel_space
-            LambdaObservationV0.__init__(
-                self, env=env, func=lambda _: self.render(), observation_space=obs_space
-            )
-        elif isinstance(env.observation_space, spaces.Dict):
-            assert pixels_key not in env.observation_space.spaces.keys()
-
-            obs_space = spaces.Dict(
-                {pixels_key: pixel_space, **env.observation_space.spaces}
-            )
-            LambdaObservationV0.__init__(
-                self,
-                env=env,
-                func=lambda obs: {pixels_key: self.render(), **obs},
-                observation_space=obs_space,
-            )
-        else:
-            obs_space = spaces.Dict(
-                {obs_key: env.observation_space, pixels_key: pixel_space}
-            )
-            LambdaObservationV0.__init__(
-                self,
-                env=env,
-                func=lambda obs: {obs_key: obs, pixels_key: self.render()},
-                observation_space=obs_space,
-            )
diff --git a/gymnasium/experimental/wrappers/lambda_reward.py b/gymnasium/experimental/wrappers/lambda_reward.py
deleted file mode 100644
index 980918e24..000000000
--- a/gymnasium/experimental/wrappers/lambda_reward.py
+++ /dev/null
@@ -1,102 +0,0 @@
-"""A collection of wrappers for modifying the reward.
- -* ``LambdaRewardV0`` - Transforms the reward by a function -* ``ClipRewardV0`` - Clips the reward between a minimum and maximum value -""" -from __future__ import annotations - -from typing import Callable, SupportsFloat - -import numpy as np - -import gymnasium as gym -from gymnasium.core import ActType, ObsType -from gymnasium.error import InvalidBound - - -__all__ = ["LambdaRewardV0", "ClipRewardV0"] - - -class LambdaRewardV0( - gym.RewardWrapper[ObsType, ActType], gym.utils.RecordConstructorArgs -): - """A reward wrapper that allows a custom function to modify the step reward. - - Example: - >>> import gymnasium as gym - >>> from gymnasium.experimental.wrappers import LambdaRewardV0 - >>> env = gym.make("CartPole-v1") - >>> env = LambdaRewardV0(env, lambda r: 2 * r + 1) - >>> _ = env.reset() - >>> _, rew, _, _, _ = env.step(0) - >>> rew - 3.0 - """ - - def __init__( - self, - env: gym.Env[ObsType, ActType], - func: Callable[[SupportsFloat], SupportsFloat], - ): - """Initialize LambdaRewardV0 wrapper. - - Args: - env (Env): The environment to wrap - func: (Callable): The function to apply to reward - """ - gym.utils.RecordConstructorArgs.__init__(self, func=func) - gym.RewardWrapper.__init__(self, env) - - self.func = func - - def reward(self, reward: SupportsFloat) -> SupportsFloat: - """Apply function to reward. - - Args: - reward (Union[float, int, np.ndarray]): environment's reward - """ - return self.func(reward) - - -class ClipRewardV0(LambdaRewardV0[ObsType, ActType], gym.utils.RecordConstructorArgs): - """A wrapper that clips the rewards for an environment between an upper and lower bound. - - Example: - >>> import gymnasium as gym - >>> from gymnasium.experimental.wrappers import ClipRewardV0 - >>> env = gym.make("CartPole-v1") - >>> env = ClipRewardV0(env, 0, 0.5) - >>> _ = env.reset() - >>> _, rew, _, _, _ = env.step(1) - >>> rew - 0.5 - """ - - def __init__( - self, - env: gym.Env[ObsType, ActType], - min_reward: float | np.ndarray | None = None, - max_reward: float | np.ndarray | None = None, - ): - """Initialize ClipRewardsV0 wrapper. 
- - Args: - env (Env): The environment to wrap - min_reward (Union[float, np.ndarray]): lower bound to apply - max_reward (Union[float, np.ndarray]): higher bound to apply - """ - if min_reward is None and max_reward is None: - raise InvalidBound("Both `min_reward` and `max_reward` cannot be None") - - elif max_reward is not None and min_reward is not None: - if np.any(max_reward - min_reward < 0): - raise InvalidBound( - f"Min reward ({min_reward}) must be smaller than max reward ({max_reward})" - ) - - gym.utils.RecordConstructorArgs.__init__( - self, min_reward=min_reward, max_reward=max_reward - ) - LambdaRewardV0.__init__( - self, env=env, func=lambda x: np.clip(x, a_min=min_reward, a_max=max_reward) - ) diff --git a/gymnasium/experimental/wrappers/vector/__init__.py b/gymnasium/experimental/wrappers/vector/__init__.py deleted file mode 100644 index 14cadaa72..000000000 --- a/gymnasium/experimental/wrappers/vector/__init__.py +++ /dev/null @@ -1,146 +0,0 @@ -"""Wrappers for vector environments.""" -# pyright: reportUnsupportedDunderAll=false -import importlib -import re - -from gymnasium.error import DeprecatedWrapper -from gymnasium.experimental.wrappers.vector.dict_info_to_list import DictInfoToListV0 -from gymnasium.experimental.wrappers.vector.record_episode_statistics import ( - RecordEpisodeStatisticsV0, -) -from gymnasium.experimental.wrappers.vector.vectorize_action import ( - ClipActionV0, - LambdaActionV0, - RescaleActionV0, - VectorizeLambdaActionV0, -) -from gymnasium.experimental.wrappers.vector.vectorize_observation import ( - DtypeObservationV0, - FilterObservationV0, - FlattenObservationV0, - GrayscaleObservationV0, - LambdaObservationV0, - RescaleObservationV0, - ReshapeObservationV0, - ResizeObservationV0, - VectorizeLambdaObservationV0, -) -from gymnasium.experimental.wrappers.vector.vectorize_reward import ( - ClipRewardV0, - LambdaRewardV0, - VectorizeLambdaRewardV0, -) - - -__all__ = [ - # --- Vector only wrappers - "VectorizeLambdaObservationV0", - "VectorizeLambdaActionV0", - "VectorizeLambdaRewardV0", - "DictInfoToListV0", - # --- Observation wrappers --- - "LambdaObservationV0", - "FilterObservationV0", - "FlattenObservationV0", - "GrayscaleObservationV0", - "ResizeObservationV0", - "ReshapeObservationV0", - "RescaleObservationV0", - "DtypeObservationV0", - # "PixelObservationV0", - # "NormalizeObservationV0", - # "TimeAwareObservationV0", - # "FrameStackObservationV0", - # "DelayObservationV0", - # --- Action Wrappers --- - "LambdaActionV0", - "ClipActionV0", - "RescaleActionV0", - # --- Reward wrappers --- - "LambdaRewardV0", - "ClipRewardV0", - # "NormalizeRewardV1", - # --- Common --- - "RecordEpisodeStatisticsV0", - # --- Rendering --- - # "RenderCollectionV0", - # "RecordVideoV0", - # "HumanRenderingV0", - # --- Conversion --- - "JaxToNumpyV0", - "JaxToTorchV0", - "NumpyToTorchV0", -] - - -# As these wrappers requires `jax` or `torch`, they are loaded by runtime on users trying to access them -# to avoid `import jax` or `import torch` on `import gymnasium`. -_wrapper_to_class = { - # data converters - "JaxToNumpyV0": "jax_to_numpy", - "JaxToTorchV0": "jax_to_torch", - "NumpyToTorchV0": "numpy_to_torch", -} - - -def __getattr__(wrapper_name: str): - """Load a wrapper by name. - - This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used. - Errors will be raised if the wrapper does not exist or if the version is not the latest. - - Args: - wrapper_name: The name of a wrapper to load. - - Returns: - The specified wrapper. 
- - Raises: - AttributeError: If the wrapper does not exist. - DeprecatedWrapper: If the version is not the latest. - """ - # Check if the requested wrapper is in the _wrapper_to_class dictionary - if wrapper_name in _wrapper_to_class: - import_stmt = ( - f"gymnasium.experimental.wrappers.vector.{_wrapper_to_class[wrapper_name]}" - ) - module = importlib.import_module(import_stmt) - return getattr(module, wrapper_name) - - # Define a regex pattern to match the integer suffix (version number) of the wrapper - int_suffix_pattern = r"(\d+)$" - version_match = re.search(int_suffix_pattern, wrapper_name) - - # If a version number is found, extract it and the base wrapper name - if version_match: - version = int(version_match.group()) - base_name = wrapper_name[: -len(version_match.group())] - else: - version = float("inf") - base_name = wrapper_name[:-2] - - # Filter the list of all wrappers to include only those with the same base name - matching_wrappers = [name for name in __all__ if name.startswith(base_name)] - - # If no matching wrappers are found, raise an AttributeError - if not matching_wrappers: - raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}") - - # Find the latest version of the matching wrappers - latest_wrapper = max( - matching_wrappers, key=lambda s: int(re.findall(int_suffix_pattern, s)[0]) - ) - latest_version = int(re.findall(int_suffix_pattern, latest_wrapper)[0]) - - # If the requested wrapper is an older version, raise a DeprecatedWrapper exception - if version < latest_version: - raise DeprecatedWrapper( - f"{wrapper_name!r} is now deprecated, use {latest_wrapper!r} instead.\n" - f"To see the changes made, go to " - f"https://gymnasium.farama.org/api/experimental/vector-wrappers/#gymnasium.experimental.wrappers.vector.{latest_wrapper}" - ) - # If the requested version is invalid, raise an AttributeError - else: - raise AttributeError( - f"module {__name__!r} has no attribute {wrapper_name!r}, did you mean {latest_wrapper!r}" - ) diff --git a/gymnasium/experimental/wrappers/vector/dict_info_to_list.py b/gymnasium/experimental/wrappers/vector/dict_info_to_list.py deleted file mode 100644 index 85f30673d..000000000 --- a/gymnasium/experimental/wrappers/vector/dict_info_to_list.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Wrapper that converts the info format for vec envs into the list format.""" -from __future__ import annotations - -from typing import Any - -from gymnasium.core import ActType, ObsType -from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv, VectorWrapper - - -__all__ = ["DictInfoToListV0"] - - -class DictInfoToListV0(VectorWrapper): - """Converts infos of vectorized environments from dict to List[dict]. - - This wrapper converts the info format of a - vector environment from a dictionary to a list of dictionaries. - This wrapper is intended to be used around vectorized - environments. If using other wrappers that perform - operation on info like `RecordEpisodeStatistics` this - need to be the outermost wrapper. - - i.e. ``DictInfoToListV0(RecordEpisodeStatisticsV0(vector_env))`` - - Example:: - - >>> import numpy as np - >>> dict_info = { - ... "k": np.array([0., 0., 0.5, 0.3]), - ... "_k": np.array([False, False, True, True]) - ... } - >>> list_info = [{}, {}, {"k": 0.5}, {"k": 0.3}] - """ - - def __init__(self, env: VectorEnv): - """This wrapper will convert the info into the list format. 
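The ``__getattr__`` above keys off a trailing integer version in each wrapper name. A sketch of that suffix convention in isolation (``split_version`` is a hypothetical helper, not part of the module):

```python
import re

def split_version(wrapper_name: str) -> tuple[str, int | None]:
    """Return (base_name, version) for names like 'ClipRewardV0'."""
    match = re.search(r"(\d+)$", wrapper_name)
    if match is None:
        return wrapper_name, None
    return wrapper_name[: -len(match.group())], int(match.group())

assert split_version("ClipRewardV0") == ("ClipRewardV", 0)
assert split_version("NormalizeRewardV1") == ("NormalizeRewardV", 1)
```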
-
-        Args:
-            env (Env): The environment to apply the wrapper
-        """
-        super().__init__(env)
-
-    def step(
-        self, actions: ActType
-    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]:
-        """Steps through the environment, converting the dict info to a list."""
-        observation, reward, terminated, truncated, infos = self.env.step(actions)
-        list_info = self._convert_info_to_list(infos)
-
-        return observation, reward, terminated, truncated, list_info
-
-    def reset(
-        self,
-        *,
-        seed: int | list[int] | None = None,
-        options: dict[str, Any] | None = None,
-    ) -> tuple[ObsType, list[dict[str, Any]]]:
-        """Resets the environment, converting the dict info to a list."""
-        obs, infos = self.env.reset(seed=seed, options=options)
-        list_info = self._convert_info_to_list(infos)
-
-        return obs, list_info
-
-    def _convert_info_to_list(self, infos: dict) -> list[dict[str, Any]]:
-        """Convert the dict info to list.
-
-        Convert the dict info of the vectorized environment
-        into a list of dictionaries where the i-th dictionary
-        has the info of the i-th environment.
-
-        Args:
-            infos (dict): info dict coming from the env.
-
-        Returns:
-            list_info (list): converted info.
-
-        """
-        list_info = [{} for _ in range(self.num_envs)]
-        for k in infos:
-            if k.startswith("_"):
-                continue
-            for i, has_info in enumerate(infos[f"_{k}"]):
-                if has_info:
-                    list_info[i][k] = infos[k][i]
-        return list_info
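A standalone version of the conversion ``_convert_info_to_list`` performs above, applied to the example data from the class docstring:

```python
import numpy as np

def dict_info_to_list(infos: dict, num_envs: int) -> list[dict]:
    list_info = [{} for _ in range(num_envs)]
    for k in infos:
        if k.startswith("_"):
            continue  # the "_k" boolean masks are consumed alongside "k"
        for i, has_info in enumerate(infos[f"_{k}"]):
            if has_info:
                list_info[i][k] = infos[k][i]
    return list_info

dict_info = {
    "k": np.array([0.0, 0.0, 0.5, 0.3]),
    "_k": np.array([False, False, True, True]),
}
assert dict_info_to_list(dict_info, 4) == [{}, {}, {"k": 0.5}, {"k": 0.3}]
```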
- """ - super().__init__(env) - - if action_space is not None: - self.action_space = action_space - - self.func = func - - def actions(self, actions: ActType) -> ActType: - """Applies the :attr:`func` to the actions.""" - return self.func(actions) - - -class VectorizeLambdaActionV0(VectorActionWrapper): - """Vectorizes a single-agent lambda action wrapper for vector environments.""" - - class VectorizedEnv(Env): - """Fake single-agent environment uses for the single-agent wrapper.""" - - def __init__(self, action_space: Space): - """Constructor for the fake environment.""" - self.action_space = action_space - - def __init__( - self, env: VectorEnv, wrapper: type[lambda_action.LambdaActionV0], **kwargs: Any - ): - """Constructor for the vectorized lambda action wrapper. - - Args: - env: The vector environment to wrap - wrapper: The wrapper to vectorize - **kwargs: Arguments for the LambdaActionV0 wrapper - """ - super().__init__(env) - - self.wrapper = wrapper( - self.VectorizedEnv(self.env.single_action_space), **kwargs - ) - self.single_action_space = self.wrapper.action_space - self.action_space = batch_space(self.single_action_space, self.num_envs) - - self.same_out = self.action_space == self.env.action_space - self.out = create_empty_array(self.single_action_space, self.num_envs) - - def actions(self, actions: ActType) -> ActType: - """Applies the wrapper to each of the action. - - Args: - actions: The actions to apply the function to - - Returns: - The updated actions using the wrapper func - """ - if self.same_out: - return concatenate( - self.single_action_space, - tuple( - self.wrapper.func(action) - for action in iterate(self.action_space, actions) - ), - actions, - ) - else: - return deepcopy( - concatenate( - self.single_action_space, - tuple( - self.wrapper.func(action) - for action in iterate(self.action_space, actions) - ), - self.out, - ) - ) - - -class ClipActionV0(VectorizeLambdaActionV0): - """Clip the continuous action within the valid :class:`Box` observation space bound.""" - - def __init__(self, env: VectorEnv): - """Constructor for the Clip Action wrapper. - - Args: - env: The vector environment to wrap - """ - super().__init__(env, lambda_action.ClipActionV0) - - -class RescaleActionV0(VectorizeLambdaActionV0): - """Affinely rescales the continuous action space of the environment to the range [min_action, max_action].""" - - def __init__( - self, - env: VectorEnv, - min_action: float | int | np.ndarray, - max_action: float | int | np.ndarray, - ): - """Initializes the :class:`RescaleAction` wrapper. - - Args: - env (Env): The vector environment to wrap - min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar. - max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar. 
- """ - super().__init__( - env, - lambda_action.RescaleActionV0, - min_action=min_action, - max_action=max_action, - ) diff --git a/gymnasium/experimental/wrappers/vector/vectorize_observation.py b/gymnasium/experimental/wrappers/vector/vectorize_observation.py deleted file mode 100644 index 5bb04ab6d..000000000 --- a/gymnasium/experimental/wrappers/vector/vectorize_observation.py +++ /dev/null @@ -1,222 +0,0 @@ -"""Vectorizes observation wrappers to works for `VectorEnv`.""" -from __future__ import annotations - -from copy import deepcopy -from typing import Any, Callable, Sequence - -import numpy as np - -from gymnasium import Space -from gymnasium.core import Env, ObsType -from gymnasium.experimental.vector import VectorEnv, VectorObservationWrapper -from gymnasium.experimental.vector.utils import batch_space, concatenate, iterate -from gymnasium.experimental.wrappers import lambda_observation -from gymnasium.vector.utils import create_empty_array - - -class LambdaObservationV0(VectorObservationWrapper): - """Transforms an observation via a function provided to the wrapper. - - The function :attr:`func` will be applied to all vector observations. - If the observations from :attr:`func` are outside the bounds of the ``env``'s observation space, provide an :attr:`observation_space`. - """ - - def __init__( - self, - env: VectorEnv, - vector_func: Callable[[ObsType], Any], - single_func: Callable[[ObsType], Any], - observation_space: Space | None = None, - ): - """Constructor for the lambda observation wrapper. - - Args: - env: The vector environment to wrap - vector_func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``. - single_func: A function that will transform an individual observation. - observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``. - """ - super().__init__(env) - - if observation_space is not None: - self.observation_space = observation_space - - self.vector_func = vector_func - self.single_func = single_func - - def vector_observation(self, observation: ObsType) -> ObsType: - """Apply function to the vector observation.""" - return self.vector_func(observation) - - def single_observation(self, observation: ObsType) -> ObsType: - """Apply function to the single observation.""" - return self.single_func(observation) - - -class VectorizeLambdaObservationV0(VectorObservationWrapper): - """Vectori`es a single-agent lambda observation wrapper for vector environments.""" - - class VectorizedEnv(Env): - """Fake single-agent environment uses for the single-agent wrapper.""" - - def __init__(self, observation_space: Space): - """Constructor for the fake environment.""" - self.observation_space = observation_space - - def __init__( - self, - env: VectorEnv, - wrapper: type[lambda_observation.LambdaObservationV0], - **kwargs: Any, - ): - """Constructor for the vectorized lambda observation wrapper. - - Args: - env: The vector environment to wrap. 
- wrapper: The wrapper to vectorize - **kwargs: Keyword argument for the wrapper - """ - super().__init__(env) - - self.wrapper = wrapper( - self.VectorizedEnv(self.env.single_observation_space), **kwargs - ) - self.single_observation_space = self.wrapper.observation_space - self.observation_space = batch_space( - self.single_observation_space, self.num_envs - ) - - self.same_out = self.observation_space == self.env.observation_space - self.out = create_empty_array(self.single_observation_space, self.num_envs) - - def vector_observation(self, observation: ObsType) -> ObsType: - """Iterates over the vector observations applying the single-agent wrapper ``observation`` then concatenates the observations together again.""" - if self.same_out: - return concatenate( - self.single_observation_space, - tuple( - self.wrapper.func(obs) - for obs in iterate(self.observation_space, observation) - ), - observation, - ) - else: - return deepcopy( - concatenate( - self.single_observation_space, - tuple( - self.wrapper.func(obs) - for obs in iterate(self.observation_space, observation) - ), - self.out, - ) - ) - - def single_observation(self, observation: ObsType) -> ObsType: - """Transforms a single observation using the wrapper transformation function.""" - return self.wrapper.func(observation) - - -class FilterObservationV0(VectorizeLambdaObservationV0): - """Vector wrapper for filtering dict or tuple observation spaces.""" - - def __init__(self, env: VectorEnv, filter_keys: Sequence[str | int]): - """Constructor for the filter observation wrapper. - - Args: - env: The vector environment to wrap - filter_keys: The subspaces to be included, use a list of strings or integers for ``Dict`` and ``Tuple`` spaces respectivesly - """ - super().__init__( - env, lambda_observation.FilterObservationV0, filter_keys=filter_keys - ) - - -class FlattenObservationV0(VectorizeLambdaObservationV0): - """Observation wrapper that flattens the observation.""" - - def __init__(self, env: VectorEnv): - """Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``. - - Args: - env: The vector environment to wrap - """ - super().__init__(env, lambda_observation.FlattenObservationV0) - - -class GrayscaleObservationV0(VectorizeLambdaObservationV0): - """Observation wrapper that converts an RGB image to grayscale.""" - - def __init__(self, env: VectorEnv, keep_dim: bool = False): - """Constructor for an RGB image based environments to make the image grayscale. - - Args: - env: The vector environment to wrap - keep_dim: If to keep the channel in the observation, if ``True``, ``obs.shape == 3`` else ``obs.shape == 2`` - """ - super().__init__( - env, lambda_observation.GrayscaleObservationV0, keep_dim=keep_dim - ) - - -class ResizeObservationV0(VectorizeLambdaObservationV0): - """Resizes image observations using OpenCV to shape.""" - - def __init__(self, env: VectorEnv, shape: tuple[int, ...]): - """Constructor that requires an image environment observation space with a shape. - - Args: - env: The vector environment to wrap - shape: The resized observation shape - """ - super().__init__(env, lambda_observation.ResizeObservationV0, shape=shape) - - -class ReshapeObservationV0(VectorizeLambdaObservationV0): - """Reshapes array based observations to shapes.""" - - def __init__(self, env: VectorEnv, shape: int | tuple[int, ...]): - """Constructor for env with Box observation space that has a shape product equal to the new shape product. 
- - Args: - env: The vector environment to wrap - shape: The reshaped observation space - """ - super().__init__(env, lambda_observation.ReshapeObservationV0, shape=shape) - - -class RescaleObservationV0(VectorizeLambdaObservationV0): - """Linearly rescales observation to between a minimum and maximum value.""" - - def __init__( - self, - env: VectorEnv, - min_obs: np.floating | np.integer | np.ndarray, - max_obs: np.floating | np.integer | np.ndarray, - ): - """Constructor that requires the env observation spaces to be a :class:`Box`. - - Args: - env: The vector environment to wrap - min_obs: The new minimum observation bound - max_obs: The new maximum observation bound - """ - super().__init__( - env, - lambda_observation.RescaleObservationV0, - min_obs=min_obs, - max_obs=max_obs, - ) - - -class DtypeObservationV0(VectorizeLambdaObservationV0): - """Observation wrapper for transforming the dtype of an observation.""" - - def __init__(self, env: VectorEnv, dtype: Any): - """Constructor for Dtype observation wrapper. - - Args: - env: The vector environment to wrap - dtype: The new dtype of the observation - """ - super().__init__(env, lambda_observation.DtypeObservationV0, dtype=dtype) diff --git a/gymnasium/experimental/wrappers/vector/vectorize_reward.py b/gymnasium/experimental/wrappers/vector/vectorize_reward.py deleted file mode 100644 index 059eb1ada..000000000 --- a/gymnasium/experimental/wrappers/vector/vectorize_reward.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Vectorizes reward function to work with `VectorEnv`.""" -from __future__ import annotations - -from typing import Any, Callable - -import numpy as np - -from gymnasium import Env -from gymnasium.experimental.vector import VectorEnv, VectorRewardWrapper -from gymnasium.experimental.vector.vector_env import ArrayType -from gymnasium.experimental.wrappers import lambda_reward - - -class LambdaRewardV0(VectorRewardWrapper): - """A reward wrapper that allows a custom function to modify the step reward.""" - - def __init__(self, env: VectorEnv, func: Callable[[ArrayType], ArrayType]): - """Initialize LambdaRewardV0 wrapper. - - Args: - env (Env): The vector environment to wrap - func: (Callable): The function to apply to reward - """ - super().__init__(env) - - self.func = func - - def reward(self, reward: ArrayType) -> ArrayType: - """Apply function to reward.""" - return self.func(reward) - - -class VectorizeLambdaRewardV0(VectorRewardWrapper): - """Vectorizes a single-agent lambda reward wrapper for vector environments.""" - - def __init__( - self, env: VectorEnv, wrapper: type[lambda_reward.LambdaRewardV0], **kwargs: Any - ): - """Constructor for the vectorized lambda reward wrapper. - - Args: - env: The vector environment to wrap. - wrapper: The wrapper to vectorize - **kwargs: Keyword argument for the wrapper - """ - super().__init__(env) - - self.wrapper = wrapper(Env(), **kwargs) - - def reward(self, reward: ArrayType) -> ArrayType: - """Iterates over the reward updating each with the wrapper func.""" - for i, r in enumerate(reward): - reward[i] = self.wrapper.func(r) - return reward - - -class ClipRewardV0(VectorizeLambdaRewardV0): - """A wrapper that clips the rewards for an environment between an upper and lower bound.""" - - def __init__( - self, - env: VectorEnv, - min_reward: float | np.ndarray | None = None, - max_reward: float | np.ndarray | None = None, - ): - """Constructor for ClipReward wrapper. 
- - Args: - env: The vector environment to wrap - min_reward: The min reward for each step - max_reward: the max reward for each step - """ - super().__init__( - env, - lambda_reward.ClipRewardV0, - min_reward=min_reward, - max_reward=max_reward, - ) diff --git a/gymnasium/experimental/functional.py b/gymnasium/functional.py similarity index 75% rename from gymnasium/experimental/functional.py rename to gymnasium/functional.py index 899e4d67c..4fb66b6f8 100644 --- a/gymnasium/experimental/functional.py +++ b/gymnasium/functional.py @@ -24,13 +24,14 @@ class FuncEnv( This API is meant to be used in a stateless manner, with the environment state being passed around explicitly. That being said, nothing here prevents users from using the environment statefully, it's just not recommended. A functional env consists of the following functions (in this case, instance methods): - - initial: returns the initial state of the POMDP - - observation: returns the observation in a given state - - transition: returns the next state after taking an action in a given state - - reward: returns the reward for a given (state, action, next_state) tuple - - terminal: returns whether a given state is terminal - - state_info: optional, returns a dict of info about a given state - - step_info: optional, returns a dict of info about a given (state, action, next_state) tuple + + * initial: returns the initial state of the POMDP + * observation: returns the observation in a given state + * transition: returns the next state after taking an action in a given state + * reward: returns the reward for a given (state, action, next_state) tuple + * terminal: returns whether a given state is terminal + * state_info: optional, returns a dict of info about a given state + * step_info: optional, returns a dict of info about a given (state, action, next_state) tuple The class-based structure serves the purpose of allowing environment constants to be defined in the class, and then using them by name in the code itself. 
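Given the functional API described above, a minimal ``FuncEnv`` subclass might look as follows. This is an illustrative 1-D random-walk sketch using the post-rename ``gymnasium.functional`` import path, not an environment from the codebase:

```python
import numpy as np
from gymnasium.functional import FuncEnv

class RandomWalk(FuncEnv):
    """Walk left/right on the integer line; the episode ends at |x| >= 10."""

    def initial(self, rng):
        return np.zeros(())  # start at the origin

    def transition(self, state, action, rng):
        return state + (1.0 if action == 1 else -1.0)

    def observation(self, state):
        return state

    def reward(self, state, action, next_state):
        return float(next_state)  # reward is the new position

    def terminal(self, state):
        return bool(abs(state) >= 10)

# The state is threaded through explicitly; the env object itself is stateless.
env = RandomWalk()
rng = np.random.default_rng(0)
state = env.initial(rng)
state = env.transition(state, 1, rng)
assert env.observation(state) == 1.0
```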
@@ -47,32 +48,32 @@ class FuncEnv( self.__dict__.update(options or {}) def initial(self, rng: Any) -> StateType: - """Initial state.""" + """Generates the initial state of the environment with a random number generator.""" raise NotImplementedError def transition(self, state: StateType, action: ActType, rng: Any) -> StateType: - """Transition.""" + """Updates (transitions) the state with an action and random number generator.""" raise NotImplementedError def observation(self, state: StateType) -> ObsType: - """Observation.""" + """Generates an observation for a given state of an environment.""" raise NotImplementedError def reward( self, state: StateType, action: ActType, next_state: StateType ) -> RewardType: - """Reward.""" + """Computes the reward for a given transition between `state`, `action` to `next_state`.""" raise NotImplementedError def terminal(self, state: StateType) -> TerminalType: - """Terminal state.""" + """Returns if the state is a final terminal state.""" raise NotImplementedError def state_info(self, state: StateType) -> dict: """Info dict about a single state.""" return {} - def step_info( + def transition_info( self, state: StateType, action: ActType, next_state: StateType ) -> dict: """Info dict about a full transition.""" @@ -82,11 +83,13 @@ class FuncEnv( """Functional transformations.""" self.initial = func(self.initial) self.transition = func(self.transition) + self.observation = func(self.observation) self.reward = func(self.reward) self.terminal = func(self.terminal) + self.state_info = func(self.state_info) - self.step_info = func(self.step_info) + self.transition_info = func(self.transition_info) def render_image( self, state: StateType, render_state: RenderStateType diff --git a/gymnasium/spaces/box.py b/gymnasium/spaces/box.py index 3b75ab4fc..418cb7158 100644 --- a/gymnasium/spaces/box.py +++ b/gymnasium/spaces/box.py @@ -274,7 +274,7 @@ class Box(Space[NDArray[Any]]): return ( isinstance(other, Box) and (self.shape == other.shape) - # and (self.dtype == other.dtype) + and (self.dtype == other.dtype) and np.allclose(self.low, other.low) and np.allclose(self.high, other.high) ) diff --git a/gymnasium/spaces/dict.py b/gymnasium/spaces/dict.py index 039759da8..6feb6c067 100644 --- a/gymnasium/spaces/dict.py +++ b/gymnasium/spaces/dict.py @@ -45,7 +45,7 @@ class Dict(Space[typing.Dict[str, Any]], typing.Mapping[str, Space[Any]]): It can be convenient to use :class:`Dict` spaces if you want to make complex observations or actions more human-readable. Usually, it will not be possible to use elements of this space directly in learning code. However, you can easily - convert `Dict` observations to flat arrays by using a :class:`gymnasium.wrappers.FlattenObservation` wrapper. + convert :class:`Dict` observations to flat arrays by using a :class:`gymnasium.wrappers.FlattenObservation` wrapper. Similar wrappers can be implemented to deal with :class:`Dict` actions. """ diff --git a/gymnasium/spaces/discrete.py b/gymnasium/spaces/discrete.py index 7bc441dbd..c6f89922d 100644 --- a/gymnasium/spaces/discrete.py +++ b/gymnasium/spaces/discrete.py @@ -62,8 +62,8 @@ class Discrete(Space[np.int64]): Args: mask: An optional mask for if an action can be selected. - Expected `np.ndarray` of shape `(n,)` and dtype `np.int8` where `1` represents valid actions and `0` invalid / infeasible actions. - If there are no possible actions (i.e. `np.all(mask == 0)`) then `space.start` will be returned. 
+ Expected ``np.ndarray`` of shape ``(n,)`` and dtype ``np.int8`` where ``1`` represents valid actions and ``0`` invalid / infeasible actions. + If there are no possible actions (i.e. ``np.all(mask == 0)``) then ``space.start`` will be returned. Returns: A sampled integer from the space diff --git a/gymnasium/spaces/graph.py b/gymnasium/spaces/graph.py index d2be47e8a..ce7639819 100644 --- a/gymnasium/spaces/graph.py +++ b/gymnasium/spaces/graph.py @@ -27,7 +27,7 @@ class GraphInstance(NamedTuple): class Graph(Space[GraphInstance]): - r"""A space representing graph information as a series of `nodes` connected with `edges` according to an adjacency matrix represented as a series of `edge_links`. + r"""A space representing graph information as a series of ``nodes`` connected with ``edges`` according to an adjacency matrix represented as a series of ``edge_links``. Example: >>> from gymnasium.spaces import Graph, Box, Discrete @@ -122,14 +122,14 @@ class Graph(Space[GraphInstance]): num_nodes: int = 10, num_edges: int | None = None, ) -> GraphInstance: - """Generates a single sample graph with num_nodes between 1 and 10 sampled from the Graph. + """Generates a single sample graph with num_nodes between ``1`` and ``10`` sampled from the Graph. Args: mask: An optional tuple of optional node and edge mask that is only possible with Discrete spaces (Box spaces don't support sample masks). - If no `num_edges` is provided then the `edge_mask` is multiplied by the number of edges - num_nodes: The number of nodes that will be sampled, the default is 10 nodes - num_edges: An optional number of edges, otherwise, a random number between 0 and `num_nodes` ^ 2 + If no ``num_edges`` is provided then the ``edge_mask`` is multiplied by the number of edges + num_nodes: The number of nodes that will be sampled, the default is ``10`` nodes + num_edges: An optional number of edges, otherwise, a random number between ``0`` and :math:`num\_nodes^2` Returns: A :class:`GraphInstance` with attributes `.nodes`, `.edges`, and `.edge_links`. @@ -212,7 +212,7 @@ class Graph(Space[GraphInstance]): def __repr__(self) -> str: """A string representation of this space. - The representation will include node_space and edge_space + The representation will include ``node_space`` and ``edge_space`` Returns: A representation of the space diff --git a/gymnasium/spaces/multi_binary.py b/gymnasium/spaces/multi_binary.py index 635003328..fe0c40239 100644 --- a/gymnasium/spaces/multi_binary.py +++ b/gymnasium/spaces/multi_binary.py @@ -65,8 +65,8 @@ class MultiBinary(Space[NDArray[np.int8]]): Args: mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``. - For mask == 0 then the samples will be 0 and mask == 1 then random samples will be generated. - The expected mask shape is the space shape and mask dtype is `np.int8`. + Where ``mask == 0`` the samples will be ``0``, and where ``mask == 1`` random samples will be generated. + The expected mask shape is the space shape and mask dtype is ``np.int8``. Returns: Sampled values from space diff --git a/gymnasium/spaces/multi_discrete.py b/gymnasium/spaces/multi_discrete.py index fcde1059d..09f930ff8 100644 --- a/gymnasium/spaces/multi_discrete.py +++ b/gymnasium/spaces/multi_discrete.py @@ -87,12 +87,12 @@ class MultiDiscrete(Space[NDArray[np.integer]]): """Generates a single random sample from this space.
Args: - mask: An optional mask for multi-discrete, expects tuples with a `np.ndarray` mask in the position of each - action with shape `(n,)` where `n` is the number of actions and `dtype=np.int8`. - Only mask values == 1 are possible to sample unless all mask values for an action are 0 then the default action `self.start` (the smallest element) is sampled. + mask: An optional mask for multi-discrete, expects tuples with a ``np.ndarray`` mask in the position of each + action with shape ``(n,)`` where ``n`` is the number of actions and ``dtype=np.int8``. + Only ``mask values == 1`` are possible to sample unless all mask values for an action are ``0`` then the default action ``self.start`` (the smallest element) is sampled. Returns: - An `np.ndarray` of shape `space.shape` + An ``np.ndarray`` of :meth:`Space.shape` """ if mask is not None: @@ -206,6 +206,7 @@ class MultiDiscrete(Space[NDArray[np.integer]]): """Check whether ``other`` is equivalent to this instance.""" return bool( isinstance(other, MultiDiscrete) + and self.dtype == other.dtype and np.all(self.nvec == other.nvec) and np.all(self.start == other.start) ) diff --git a/gymnasium/spaces/sequence.py b/gymnasium/spaces/sequence.py index 1311316fe..df221684d 100644 --- a/gymnasium/spaces/sequence.py +++ b/gymnasium/spaces/sequence.py @@ -38,7 +38,7 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]): Args: space: Elements in the sequences this space represent must belong to this space. seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space. - stack: If `True` then the resulting samples would be stacked. + stack: If ``True`` then the resulting samples would be stacked. """ assert isinstance( space, Space @@ -78,14 +78,13 @@ class Sequence(Space[Union[typing.Tuple[Any, ...], Any]]): Args: mask: An optional mask for (optionally) the length of the sequence and (optionally) the values in the sequence. - If you specify `mask`, it is expected to be a tuple of the form `(length_mask, sample_mask)` where `length_mask` - is + If you specify ``mask``, it is expected to be a tuple of the form ``(length_mask, sample_mask)`` where ``length_mask`` is * ``None`` The length will be randomly drawn from a geometric distribution * ``np.ndarray`` of integers, in which case the length of the sampled sequence is randomly drawn from this array. * ``int`` for a fixed length sample - The second element of the mask tuple `sample` mask specifies a mask that is applied when + The second element of the mask tuple ``sample`` mask specifies a mask that is applied when sampling elements from the base space. The mask is applied for each feature space sample. Returns: diff --git a/gymnasium/spaces/text.py b/gymnasium/spaces/text.py index be15be13f..5114b2ec3 100644 --- a/gymnasium/spaces/text.py +++ b/gymnasium/spaces/text.py @@ -78,13 +78,13 @@ class Text(Space[str]): self, mask: None | (tuple[int | None, NDArray[np.int8] | None]) = None, ) -> str: - """Generates a single random sample from this space with by default a random length between `min_length` and `max_length` and sampled from the `charset`. + """Generates a single random sample from this space with by default a random length between ``min_length`` and ``max_length`` and sampled from the ``charset``. Args: mask: An optional tuples of length and mask for the text. - The length is expected to be between the `min_length` and `max_length` otherwise a random integer between `min_length` and `max_length` is selected. 
- For the mask, we expect a numpy array of length of the charset passed with `dtype == np.int8`. - If the charlist mask is all zero then an empty string is returned no matter the `min_length` + The length is expected to be between the ``min_length`` and ``max_length`` otherwise a random integer between ``min_length`` and ``max_length`` is selected. + For the mask, we expect a numpy array of length of the charset passed with ``dtype == np.int8``. + If the charlist mask is all zero then an empty string is returned no matter the ``min_length`` Returns: A sampled string from the space diff --git a/gymnasium/spaces/tuple.py b/gymnasium/spaces/tuple.py index f117845f8..35d768d1c 100644 --- a/gymnasium/spaces/tuple.py +++ b/gymnasium/spaces/tuple.py @@ -53,8 +53,8 @@ class Tuple(Space[typing.Tuple[Any, ...]], typing.Sequence[Any]): Depending on the type of seed, the subspaces will be seeded differently * ``None`` - All the subspaces will use a random initial seed - * ``Int`` - The integer is used to seed the `Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all of the subspaces. - * ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces (``List(42, 54, ...``). + * ``Int`` - The integer is used to seed the :class:`Tuple` space that is used to generate seed values for each of the subspaces. Warning, this does not guarantee unique seeds for all the subspaces. + * ``List`` - Values used to seed the subspaces. This allows the seeding of multiple composite subspaces ``[42, 54, ...]``. Args: seed: An optional list of ints or int to seed the (sub-)spaces. diff --git a/gymnasium/spaces/utils.py b/gymnasium/spaces/utils.py index 4265bb8d0..31fef9df7 100644 --- a/gymnasium/spaces/utils.py +++ b/gymnasium/spaces/utils.py @@ -428,9 +428,7 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph: Raises: NotImplementedError: if the space is not defined in :mod:`gymnasium.spaces`. 
- Example: - Flatten spaces.Box: - + Example - Flatten spaces.Box: >>> from gymnasium.spaces import Box >>> box = Box(0.0, 1.0, shape=(3, 4, 5)) >>> box @@ -440,8 +438,7 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph: >>> flatten(box, box.sample()) in flatten_space(box) True - Flatten spaces.Discrete: - + Example - Flatten spaces.Discrete: >>> from gymnasium.spaces import Discrete >>> discrete = Discrete(5) >>> flatten_space(discrete) @@ -449,8 +446,7 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph: >>> flatten(discrete, discrete.sample()) in flatten_space(discrete) True - Flatten spaces.Dict: - + Example - Flatten spaces.Dict: >>> from gymnasium.spaces import Dict, Discrete, Box >>> space = Dict({"position": Discrete(2), "velocity": Box(0, 1, shape=(2, 2))}) >>> flatten_space(space) @@ -458,8 +454,7 @@ def flatten_space(space: Space[Any]) -> Box | Dict | Sequence | Tuple | Graph: >>> flatten(space, space.sample()) in flatten_space(space) True - Flatten spaces.Graph: - + Example - Flatten spaces.Graph: >>> from gymnasium.spaces import Graph, Discrete, Box >>> space = Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)) >>> flatten_space(space) diff --git a/gymnasium/utils/env_checker.py b/gymnasium/utils/env_checker.py index 1737f9ec9..6482b86ea 100644 --- a/gymnasium/utils/env_checker.py +++ b/gymnasium/utils/env_checker.py @@ -1,4 +1,4 @@ -"""A set of functions for checking an environment details. +"""A set of functions for checking an environment implementation. This file is originally from the Stable Baselines3 repository hosted on GitHub (https://github.com/DLR-RM/stable-baselines3/) @@ -63,7 +63,7 @@ def data_equivalence(data_1, data_2) -> bool: return False -def check_reset_seed(env: gym.Env) -> None: +def check_reset_seed(env: gym.Env): """Check that the environment can be reset with a seed. Args: @@ -132,7 +132,7 @@ def check_reset_seed(env: gym.Env) -> None: ) -def check_reset_options(env: gym.Env) -> None: +def check_reset_options(env: gym.Env): """Check that the environment can be reset with options. Args: @@ -160,7 +160,7 @@ def check_reset_options(env: gym.Env) -> None: ) -def check_reset_return_info_deprecation(env: gym.Env) -> None: +def check_reset_return_info_deprecation(env: gym.Env): """Makes sure support for deprecated `return_info` argument is dropped. Args: @@ -177,7 +177,7 @@ def check_reset_return_info_deprecation(env: gym.Env) -> None: ) -def check_seed_deprecation(env: gym.Env) -> None: +def check_seed_deprecation(env: gym.Env): """Makes sure support for deprecated function `seed` is dropped. Args: @@ -193,7 +193,7 @@ def check_seed_deprecation(env: gym.Env) -> None: ) -def check_reset_return_type(env: gym.Env) -> None: +def check_reset_return_type(env: gym.Env): """Checks that :meth:`reset` correctly returns a tuple of the form `(obs , info)`. 
Args: @@ -218,7 +218,7 @@ def check_reset_return_type(env: gym.Env) -> None: ), f"The second element returned by `env.reset()` was not a dictionary, actual type: {type(info)}" -def check_space_limit(space: spaces.Space, space_type: str) -> None: +def check_space_limit(space, space_type: str): """Check the space limit for only the Box space as a test that only runs as part of `check_env`.""" if isinstance(space, spaces.Box): if np.any(np.equal(space.low, -np.inf)): @@ -256,18 +256,19 @@ def check_space_limit(space: spaces.Space, space_type: str) -> None: check_space_limit(subspace, space_type) -def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = False) -> None: - """Check that an environment follows Gym API. +def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = False): + """Check that an environment follows Gymnasium's API. - This is an invasive function that calls the environment's reset and step. + .. py:currentmodule:: gymnasium.Env - This is particularly useful when using a custom environment. - Please take a look at https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/ - for more information about the API. + To ensure that an environment is implemented "correctly", ``check_env`` checks that the :attr:`observation_space` and :attr:`action_space` are correct. + Furthermore, the function will call the :meth:`reset`, :meth:`step` and :meth:`render` functions with a variety of values. + + We highly recommend calling this function after an environment is constructed and within a project's continuous integration, to keep the environment up to date with Gymnasium's API. Args: env: The Gym environment that will be checked - warn: Ignored + warn: Ignored; previously this silenced particular warnings skip_render_check: Whether to skip the checks for the render method. False by default (skipping is useful for the CI) """ if warn is not None: diff --git a/gymnasium/utils/passive_env_checker.py b/gymnasium/utils/passive_env_checker.py index ff448536a..d9272dafe 100644 --- a/gymnasium/utils/passive_env_checker.py +++ b/gymnasium/utils/passive_env_checker.py @@ -12,6 +12,8 @@ __all__ = [ "env_render_passive_checker", "env_reset_passive_checker", "env_step_passive_checker", + "check_action_space", + "check_observation_space", ] diff --git a/gymnasium/utils/play.py b/gymnasium/utils/play.py index e1bddcba6..7c461691b 100644 --- a/gymnasium/utils/play.py +++ b/gymnasium/utils/play.py @@ -1,6 +1,8 @@ """Utilities for visualising an environment.""" +from __future__ import annotations + from collections import deque -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, List import numpy as np @@ -40,8 +42,8 @@ class PlayableGame: def __init__( self, env: Env, - keys_to_action: Optional[Dict[Tuple[int, ...], int]] = None, - zoom: Optional[float] = None, + keys_to_action: dict[tuple[int, ...], int] | None = None, + zoom: float | None = None, ): """Wraps an environment with a dictionary of keyboard buttons to actions and whether to zoom in on the environment.
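Stepping back from the diff for a moment: the ``check_env`` changes above recommend running the checker in continuous integration. As a usage sketch (the environment ID and test layout are illustrative, not part of this patch):

```python
# e.g. in a unit test executed by the project's continuous integration
import gymnasium as gym
from gymnasium.utils.env_checker import check_env


def test_my_env_follows_gymnasium_api():
    env = gym.make("CartPole-v1")  # substitute your own registered environment
    # Raises an informative error (or warns) if the spaces, reset, step or
    # render implementations violate the Gymnasium API.
    check_env(env, skip_render_check=True)
```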
@@ -66,7 +68,7 @@ class PlayableGame: self.running = True def _get_relevant_keys( - self, keys_to_action: Optional[Dict[Tuple[int], int]] = None + self, keys_to_action: dict[tuple[int], int] | None = None ) -> set: if keys_to_action is None: if hasattr(self.env, "get_keys_to_action"): @@ -83,7 +85,7 @@ class PlayableGame: relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), [])) return relevant_keys - def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]: + def _get_video_size(self, zoom: float | None = None) -> tuple[int, int]: rendered = self.env.render() if isinstance(rendered, List): rendered = rendered[-1] @@ -123,7 +125,7 @@ class PlayableGame: def display_arr( - screen: Surface, arr: np.ndarray, video_size: Tuple[int, int], transpose: bool + screen: Surface, arr: np.ndarray, video_size: tuple[int, int], transpose: bool ): """Displays a numpy array on screen. @@ -147,15 +149,15 @@ def display_arr( def play( env: Env, - transpose: Optional[bool] = True, - fps: Optional[int] = None, - zoom: Optional[float] = None, - callback: Optional[Callable] = None, - keys_to_action: Optional[Dict[Union[Tuple[Union[str, int]], str], ActType]] = None, - seed: Optional[int] = None, + transpose: bool | None = True, + fps: int | None = None, + zoom: float | None = None, + callback: Callable | None = None, + keys_to_action: dict[tuple[str | int] | str, ActType] | None = None, + seed: int | None = None, noop: ActType = 0, ): - """Allows one to play the game using keyboard. + """Allows the user to play the environment using a keyboard. Args: env: Environment to use for playing. @@ -164,13 +166,14 @@ def play( ``env.metadata["render_fps"]`` (or 30, if the environment does not specify "render_fps") is used. zoom: Zoom the observation in by a ``zoom`` amount; should be a positive float callback: If a callback is provided, it will be executed after every step. It takes the following input: - obs_t: observation before performing action - obs_tp1: observation after performing action - action: action that was executed - rew: reward that was received - terminated: whether the environment is terminated or not - truncated: whether the environment is truncated or not - info: debug info + + * obs_t: observation before performing action + * obs_tp1: observation after performing action + * action: action that was executed + * rew: reward that was received + * terminated: whether the environment is terminated or not + * truncated: whether the environment is truncated or not + * info: debug info keys_to_action: Mapping from keys pressed to action performed. Different formats are supported: Key combinations can either be expressed as a tuple of unicode code points of the keys, as a tuple of characters, or as a string where each character of the string represents @@ -205,28 +208,29 @@ def play( noop: The action used when no key input has been entered, or the entered key combination is unknown. Example: - >>> import gymnasium as gym >>> from gymnasium.utils.play import play - >>> play(gym.make("CarRacing-v2", render_mode="rgb_array"), keys_to_action={ # doctest: +SKIP - ... "w": np.array([0, 0.7, 0]), - ... "a": np.array([-1, 0, 0]), - ... "s": np.array([0, 0, 1]), - ... "d": np.array([1, 0, 0]), - ... "wa": np.array([-1, 0.7, 0]), - ... "dw": np.array([1, 0.7, 0]), - ... "ds": np.array([1, 0, 1]), - ... "as": np.array([-1, 0, 1]), - ... }, noop=np.array([0,0,0])) + >>> play(gym.make("CarRacing-v2", render_mode="rgb_array"), # doctest: +SKIP + ... keys_to_action={ + ...
"w": np.array([0, 0.7, 0]), + ... "a": np.array([-1, 0, 0]), + ... "s": np.array([0, 0, 1]), + ... "d": np.array([1, 0, 0]), + ... "wa": np.array([-1, 0.7, 0]), + ... "dw": np.array([1, 0.7, 0]), + ... "ds": np.array([1, 0, 1]), + ... "as": np.array([-1, 0, 1]), + ... }, + ... noop=np.array([0, 0, 0]) + ... ) Above code works also if the environment is wrapped, so it's particularly useful in verifying that the frame-level preprocessing does not render the game unplayable. If you wish to plot real time statistics as you play, you can use - :class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward + :class:`PlayPlot`. Here's a sample code for plotting the reward for last 150 steps. - >>> import gymnasium as gym >>> from gymnasium.utils.play import PlayPlot, play >>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): ... return [rew,] @@ -321,7 +325,7 @@ class PlayPlot: """ def __init__( - self, callback: Callable, horizon_timesteps: int, plot_names: List[str] + self, callback: Callable, horizon_timesteps: int, plot_names: list[str] ): """Constructor of :class:`PlayPlot`. @@ -355,7 +359,7 @@ class PlayPlot: for axis, name in zip(self.ax, plot_names): axis.set_title(name) self.t = 0 - self.cur_plot: List[Optional[plt.Axes]] = [None for _ in range(num_plots)] + self.cur_plot: list[plt.Axes | None] = [None for _ in range(num_plots)] self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)] def callback( diff --git a/gymnasium/utils/save_video.py b/gymnasium/utils/save_video.py index 7ca4d019c..9d9dc7a04 100644 --- a/gymnasium/utils/save_video.py +++ b/gymnasium/utils/save_video.py @@ -15,9 +15,9 @@ except ImportError as e: def capped_cubic_video_schedule(episode_id: int) -> bool: - """The default episode trigger. + r"""The default episode trigger. - This function will trigger recordings at the episode indices 0, 1, 4, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ... + This function will trigger recordings at the episode indices :math:`\{0, 1, 4, 8, 27, ..., k^3, ..., 729, 1000, 2000, 3000, ...\}` Args: episode_id: The episode number diff --git a/gymnasium/utils/seeding.py b/gymnasium/utils/seeding.py index a93dfbe5b..95625f397 100644 --- a/gymnasium/utils/seeding.py +++ b/gymnasium/utils/seeding.py @@ -1,22 +1,29 @@ """Set of random number generator functions: seeding, generator, hashing seeds.""" -from typing import Any, Optional, Tuple +from __future__ import annotations import numpy as np from gymnasium import error -def np_random(seed: Optional[int] = None) -> Tuple[np.random.Generator, Any]: - """Generates a random number generator from the seed and returns the Generator and seed. +def np_random(seed: int | None = None) -> tuple[np.random.Generator, int]: + """Returns a NumPy random number generator (RNG) along with seed value from the inputted seed. + + If ``seed`` is ``None`` then a **random** seed will be generated as the RNG's initial seed. + This randomly selected seed is returned as the second value of the tuple. + + .. py:currentmodule:: gymnasium.Env + + This function is called in :meth:`reset` to reset an environment's initial RNG. 
Args: seed: The seed used to create the generator Returns: - The generator and resulting seed + A NumPy-based Random Number Generator and generator seed Raises: - Error: Seed must be a non-negative integer or omitted + Error: Seed must be a non-negative integer """ if seed is not None and not (isinstance(seed, int) and 0 <= seed): if isinstance(seed, int) is False: diff --git a/gymnasium/utils/step_api_compatibility.py b/gymnasium/utils/step_api_compatibility.py index d0b0d71c2..e56e882c4 100644 --- a/gymnasium/utils/step_api_compatibility.py +++ b/gymnasium/utils/step_api_compatibility.py @@ -1,4 +1,6 @@ """Contains methods for step compatibility, from old-to-new and new-to-old API.""" +from __future__ import annotations + from typing import SupportsFloat, Tuple, Union import numpy as np @@ -23,13 +25,15 @@ TerminatedTruncatedStepType = Tuple[ def convert_to_terminated_truncated_step_api( - step_returns: Union[DoneStepType, TerminatedTruncatedStepType], is_vector_env=False + step_returns: DoneStepType | TerminatedTruncatedStepType, is_vector_env=False ) -> TerminatedTruncatedStepType: """Function to transform step returns to new step API irrespective of input API. + .. py:currentmodule:: gymnasium.Env + Args: - step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) - is_vector_env (bool): Whether the step_returns are from a vector environment + step_returns (tuple): Items returned by :meth:`step`. Can be ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)`` + is_vector_env (bool): Whether the ``step_returns`` are from a vector environment """ if len(step_returns) == 5: return step_returns @@ -75,14 +79,16 @@ def convert_to_terminated_truncated_step_api( def convert_to_done_step_api( - step_returns: Union[TerminatedTruncatedStepType, DoneStepType], + step_returns: TerminatedTruncatedStepType | DoneStepType, is_vector_env: bool = False, ) -> DoneStepType: """Function to transform step returns to old step API irrespective of input API. + .. py:currentmodule:: gymnasium.Env + Args: - step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) - is_vector_env (bool): Whether the step_returns are from a vector environment + step_returns (tuple): Items returned by :meth:`step`. Can be ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)`` + is_vector_env (bool): Whether the ``step_returns`` are from a vector environment """ if len(step_returns) == 4: return step_returns @@ -130,38 +136,41 @@ def convert_to_done_step_api( def step_api_compatibility( - step_returns: Union[TerminatedTruncatedStepType, DoneStepType], + step_returns: TerminatedTruncatedStepType | DoneStepType, output_truncation_bool: bool = True, is_vector_env: bool = False, -) -> Union[TerminatedTruncatedStepType, DoneStepType]: - """Function to transform step returns to the API specified by `output_truncation_bool` bool. +) -> TerminatedTruncatedStepType | DoneStepType: + """Function to transform step returns to the API specified by ``output_truncation_bool``. - Done (old) step API refers to step() method returning (observation, reward, done, info) - Terminated Truncated (new) step API refers to step() method returning (observation, reward, terminated, truncated, info) + .. 
py:currentmodule:: gymnasium.Env + + Done (old) step API refers to :meth:`step` method returning ``(observation, reward, done, info)`` + Terminated Truncated (new) step API refers to :meth:`step` method returning ``(observation, reward, terminated, truncated, info)`` (Refer to docs for details on the API change) Args: - step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) - output_truncation_bool (bool): Whether the output should return two booleans (new API) or one (old) (True by default) - is_vector_env (bool): Whether the step_returns are from a vector environment + step_returns (tuple): Items returned by :meth:`step`. Can be ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)`` + output_truncation_bool (bool): Whether the output should return two booleans (new API) or one (old) (``True`` by default) + is_vector_env (bool): Whether the ``step_returns`` are from a vector environment Returns: - step_returns (tuple): Depending on `output_truncation_bool` bool, it can return `(obs, rew, done, info)` or `(obs, rew, terminated, truncated, info)` + step_returns (tuple): Depending on ``output_truncation_bool``, it can return ``(obs, rew, done, info)`` or ``(obs, rew, terminated, truncated, info)`` Example: - This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API, - wrapper is written in new API, and the final step output is desired to be in old API. + This function can be used to ensure compatibility in step interfaces with conflicting API. E.g. if env is written in old API, + wrapper is written in new API, and the final step output is desired to be in old API. >>> import gymnasium as gym >>> env = gym.make("CartPole-v0") - >>> _ = env.reset() - >>> obs, rewards, done, info = step_api_compatibility(env.step(0), output_truncation_bool=False) - >>> obs, rewards, terminated, truncated, info = step_api_compatibility(env.step(0), output_truncation_bool=True) + >>> _, _ = env.reset() + >>> obs, reward, done, info = step_api_compatibility(env.step(0), output_truncation_bool=False) + >>> obs, reward, terminated, truncated, info = step_api_compatibility(env.step(0), output_truncation_bool=True) - >>> vec_env = gym.vector.make("CartPole-v0") - >>> _ = vec_env.reset() + >>> vec_env = gym.make_vec("CartPole-v0", vectorization_mode="sync") + >>> _, _ = vec_env.reset() >>> obs, rewards, dones, infos = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=False) - >>> obs, rewards, terminated, truncated, info = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=True) + >>> obs, rewards, terminations, truncations, infos = step_api_compatibility(vec_env.step([0]), is_vector_env=True, output_truncation_bool=True) + """ if output_truncation_bool: return convert_to_terminated_truncated_step_api(step_returns, is_vector_env) diff --git a/gymnasium/vector/__init__.py b/gymnasium/vector/__init__.py index 815aff7c2..e3421f840 100644 --- a/gymnasium/vector/__init__.py +++ b/gymnasium/vector/__init__.py @@ -1,85 +1,23 @@ -"""Module for vector environments.""" -from typing import Callable, Iterable, List, Optional, Union - -import gymnasium as gym -from gymnasium.core import Env +"""Experimental vector env API.""" from gymnasium.vector import utils from gymnasium.vector.async_vector_env import AsyncVectorEnv from gymnasium.vector.sync_vector_env import SyncVectorEnv -from gymnasium.vector.vector_env 
import VectorEnv, VectorEnvWrapper +from gymnasium.vector.vector_env import ( + VectorActionWrapper, + VectorEnv, + VectorObservationWrapper, + VectorRewardWrapper, + VectorWrapper, +) __all__ = [ - "AsyncVectorEnv", - "SyncVectorEnv", "VectorEnv", - "VectorEnvWrapper", - "make", + "VectorWrapper", + "VectorObservationWrapper", + "VectorActionWrapper", + "VectorRewardWrapper", + "SyncVectorEnv", + "AsyncVectorEnv", "utils", ] - - -def make( - id: str, - num_envs: int = 1, - asynchronous: bool = True, - wrappers: Optional[Union[Callable[[Env], Env], List[Callable[[Env], Env]]]] = None, - disable_env_checker: Optional[bool] = None, - **kwargs, -) -> VectorEnv: - """Create a vectorized environment from multiple copies of an environment, from its id. - - Args: - id: The environment ID. This must be a valid ID from the registry. - num_envs: Number of copies of the environment. - asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing` to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`. - wrappers: If not ``None``, then apply the wrappers to each internal environment during creation. - disable_env_checker: If to run the env checker for the first environment only. None will default to the environment spec `disable_env_checker` parameter - (that is by default False), otherwise will run according to this argument (True = not run, False = run) - **kwargs: Keywords arguments applied during `gym.make` - - Returns: - The vectorized environment. - - Example: - >>> import gymnasium as gym - >>> env = gym.vector.make('CartPole-v1', num_envs=3) - >>> env.reset(seed=42) - (array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], - [ 0.01522993, -0.04562247, -0.04799704, 0.03392126], - [-0.03774345, -0.02418869, -0.00942293, 0.0469184 ]], - dtype=float32), {}) - """ - gym.logger.warn( - "`gymnasium.vector.make(...)` is deprecated and will be replaced by `gymnasium.make_vec(...)` in v1.0" - ) - - def create_env(env_num: int) -> Callable[[], Env]: - """Creates an environment that can enable or disable the environment checker.""" - # If the env_num > 0 then disable the environment checker otherwise use the parameter - _disable_env_checker = True if env_num > 0 else disable_env_checker - - def _make_env() -> Env: - env = gym.envs.registration.make( - id, - disable_env_checker=_disable_env_checker, - **kwargs, - ) - if wrappers is not None: - if callable(wrappers): - env = wrappers(env) - elif isinstance(wrappers, Iterable) and all( - [callable(w) for w in wrappers] - ): - for wrapper in wrappers: - env = wrapper(env) - else: - raise NotImplementedError - return env - - return _make_env - - env_fns = [ - create_env(disable_env_checker or env_num > 0) for env_num in range(num_envs) - ] - return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns) diff --git a/gymnasium/vector/async_vector_env.py b/gymnasium/vector/async_vector_env.py index e6a874ae5..ea724cf25 100644 --- a/gymnasium/vector/async_vector_env.py +++ b/gymnasium/vector/async_vector_env.py @@ -1,17 +1,19 @@ """An async vector environment.""" -import multiprocessing as mp +from __future__ import annotations + +import multiprocessing import sys import time from copy import deepcopy from enum import Enum -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from multiprocessing import Queue +from multiprocessing.connection import Connection +from typing import Any, Callable, Sequence import numpy as np -from 
numpy.typing import NDArray -import gymnasium as gym from gymnasium import logger -from gymnasium.core import Env, ObsType +from gymnasium.core import ActType, Env, ObsType, RenderFrame from gymnasium.error import ( AlreadyPendingCallError, ClosedEnvironmentError, @@ -20,6 +22,7 @@ from gymnasium.error import ( ) from gymnasium.vector.utils import ( CloudpickleWrapper, + batch_space, clear_mpi_env_vars, concatenate, create_empty_array, @@ -28,13 +31,15 @@ from gymnasium.vector.utils import ( read_from_shared_memory, write_to_shared_memory, ) -from gymnasium.vector.vector_env import VectorEnv +from gymnasium.vector.vector_env import ArrayType, VectorEnv -__all__ = ["AsyncVectorEnv"] +__all__ = ["AsyncVectorEnv", "AsyncState"] class AsyncState(Enum): + """The AsyncVectorEnv possible states given the different actions.""" + DEFAULT = "default" WAITING_RESET = "reset" WAITING_STEP = "step" @@ -48,39 +53,57 @@ class AsyncVectorEnv(VectorEnv): Example: >>> import gymnasium as gym - >>> env = gym.vector.AsyncVectorEnv([ + >>> envs = gym.make_vec("Pendulum-v1", num_envs=2, vectorization_mode="async") + >>> envs + AsyncVectorEnv(Pendulum-v1, num_envs=2) + >>> envs = gym.vector.AsyncVectorEnv([ ... lambda: gym.make("Pendulum-v1", g=9.81), ... lambda: gym.make("Pendulum-v1", g=1.62) ... ]) - >>> env.reset(seed=42) - (array([[-0.14995256, 0.9886932 , -0.12224312], - [ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {}) + >>> envs + AsyncVectorEnv(num_envs=2) + >>> observations, infos = envs.reset(seed=42) + >>> observations + array([[-0.14995256, 0.9886932 , -0.12224312], + [ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32) + >>> infos + {} + >>> _ = envs.action_space.seed(123) + >>> observations, rewards, terminations, truncations, infos = envs.step(envs.action_space.sample()) + >>> observations + array([[-0.1851753 , 0.98270553, 0.714599 ], + [ 0.6193494 , 0.7851154 , -1.0808398 ]], dtype=float32) + >>> rewards + array([-2.96495728, -1.00214607]) + >>> terminations + array([False, False]) + >>> truncations + array([False, False]) + >>> infos + {} """ def __init__( self, env_fns: Sequence[Callable[[], Env]], - observation_space: Optional[gym.Space] = None, - action_space: Optional[gym.Space] = None, shared_memory: bool = True, copy: bool = True, - context: Optional[str] = None, + context: str | None = None, daemon: bool = True, - worker: Optional[Callable] = None, + worker: Callable[ + [int, Callable[[], Env], Connection, Connection, bool, Queue], None + ] + | None = None, ): """Vectorized environment that runs multiple environments in parallel. Args: env_fns: Functions that create the environments. - observation_space: Observation space of a single environment. If ``None``, - then the observation space of the first environment is taken. - action_space: Action space of a single environment. If ``None``, - then the action space of the first environment is taken. shared_memory: If ``True``, then the observations from the worker processes are communicated back through shared variables. This can improve the efficiency if the observations are large (e.g. images). - copy: If ``True``, then the :meth:`~AsyncVectorEnv.reset` and :meth:`~AsyncVectorEnv.step` methods + copy: If ``True``, then the :meth:`AsyncVectorEnv.reset` and :meth:`AsyncVectorEnv.step` methods return a copy of the observations. - context: Context for `multiprocessing`_. If ``None``, then the default context is used. + context: Context for `multiprocessing`. If ``None``, then the default context is used. 
daemon: If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they will quit if the head process quits. However, ``daemon=True`` prevents subprocesses from spawning children, so for some environments you may want to have it set to ``False``. @@ -98,24 +121,33 @@ class AsyncVectorEnv(VectorEnv): ValueError: If observation_space is a custom space (i.e. not a default space in Gym, such as gymnasium.spaces.Box, gymnasium.spaces.Discrete, or gymnasium.spaces.Dict) and shared_memory is True. """ - ctx = mp.get_context(context) self.env_fns = env_fns self.shared_memory = shared_memory self.copy = copy - dummy_env = env_fns[0]() - self.metadata = dummy_env.metadata - if (observation_space is None) or (action_space is None): - observation_space = observation_space or dummy_env.observation_space - action_space = action_space or dummy_env.action_space + self.num_envs = len(env_fns) + + # This would be nice to get rid of, but without it there's a deadlock between shared memory and pipes + # Create a dummy environment to gather the metadata and observation / action space of the environment + dummy_env = env_fns[0]() + + # As we support `make_vec(spec)` then we can't include a `spec = dummy_env.spec` as this doesn't guarantee we can actually recreate the vector env. + self.metadata = dummy_env.metadata + self.render_mode = dummy_env.render_mode + + self.single_observation_space = dummy_env.observation_space + self.single_action_space = dummy_env.action_space + + self.observation_space = batch_space( + self.single_observation_space, self.num_envs + ) + self.action_space = batch_space(self.single_action_space, self.num_envs) + + dummy_env.close() del dummy_env - super().__init__( - num_envs=len(env_fns), - observation_space=observation_space, - action_space=action_space, - ) + # Generate the multiprocessing context for the observation buffer + ctx = multiprocessing.get_context(context) if self.shared_memory: try: _obs_buffer = create_shared_memory( @@ -126,12 +158,9 @@ class AsyncVectorEnv(VectorEnv): ) except CustomSpaceError as e: raise ValueError( - "Using `shared_memory=True` in `AsyncVectorEnv` " - "is incompatible with non-standard Gymnasium observation spaces " - "(i.e. custom spaces inheriting from `gymnasium.Space`), and is " - "only compatible with default Gymnasium spaces (e.g. `Box`, " - "`Tuple`, `Dict`) for batching. Set `shared_memory=False` " - "if you use custom observation spaces." + "Using `shared_memory=True` in `AsyncVectorEnv` is incompatible with non-standard Gymnasium observation spaces (i.e. custom spaces inheriting from `gymnasium.Space`), " + "and is only compatible with default Gymnasium spaces (e.g. `Box`, `Tuple`, `Dict`) for batching. " + "Set `shared_memory=False` if you use custom observation spaces." ) from e else: _obs_buffer = None @@ -141,8 +170,7 @@ class AsyncVectorEnv(VectorEnv): self.parent_pipes, self.processes = [], [] self.error_queue = ctx.Queue() - target = _worker_shared_memory if self.shared_memory else _worker - target = worker or target + target = worker or _async_worker with clear_mpi_env_vars(): for idx, env_fn in enumerate(self.env_fns): parent_pipe, child_pipe = ctx.Pipe() @@ -169,10 +197,28 @@ class AsyncVectorEnv(VectorEnv): self._state = AsyncState.DEFAULT self._check_spaces() + def reset( + self, + *, + seed: int | list[int] | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[ObsType, dict[str, Any]]: + """Resets all sub-environments in parallel and returns a batch of concatenated observations and info.
+ + Args: + seed: The environment reset seeds + options: The reset options to pass to each sub-environment + + Returns: + A batch of observations and info from the vectorized environment. + """ + self.reset_async(seed=seed, options=options) + return self.reset_wait() + def reset_async( self, - seed: Optional[Union[int, List[int]]] = None, - options: Optional[dict] = None, + seed: int | list[int] | None = None, + options: dict | None = None, ): """Send calls to the :obj:`reset` methods of the sub-environments. @@ -192,38 +238,29 @@ class AsyncVectorEnv(VectorEnv): if seed is None: seed = [None for _ in range(self.num_envs)] - if isinstance(seed, int): + elif isinstance(seed, int): seed = [seed + i for i in range(self.num_envs)] assert len(seed) == self.num_envs if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete", - self._state.value, + str(self._state.value), ) - for pipe, single_seed in zip(self.parent_pipes, seed): - single_kwargs = {} - if single_seed is not None: - single_kwargs["seed"] = single_seed - if options is not None: - single_kwargs["options"] = options - - pipe.send(("reset", single_kwargs)) + for pipe, env_seed in zip(self.parent_pipes, seed): + env_kwargs = {"seed": env_seed, "options": options} + pipe.send(("reset", env_kwargs)) self._state = AsyncState.WAITING_RESET def reset_wait( self, - timeout: Optional[Union[int, float]] = None, - seed: Optional[int] = None, - options: Optional[dict] = None, - ) -> Union[ObsType, Tuple[ObsType, dict]]: + timeout: int | float | None = None, + ) -> tuple[ObsType, dict[str, Any]]: """Waits for the calls triggered by :meth:`reset_async` to finish and returns the results. Args: - timeout: Number of seconds before the call to `reset_wait` times out. If `None`, the call to `reset_wait` never times out. - seed: ignored - options: ignored + timeout: Number of seconds before the call to ``reset_wait`` times out. If ``None``, the call to ``reset_wait`` never times out. Returns: A tuple of batched observations and list of dictionaries @@ -240,15 +277,14 @@ class AsyncVectorEnv(VectorEnv): AsyncState.WAITING_RESET.value, ) - if not self._poll(timeout): + if not self._poll_pipe_envs(timeout): self._state = AsyncState.DEFAULT - raise mp.TimeoutError( + raise multiprocessing.TimeoutError( f"The call to `reset_wait` has timed out after {timeout} second(s)." ) results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) self._raise_if_errors(successes) - self._state = AsyncState.DEFAULT infos = {} results, info_data = zip(*results) @@ -260,13 +296,28 @@ class AsyncVectorEnv(VectorEnv): self.single_observation_space, results, self.observations ) + self._state = AsyncState.DEFAULT return (deepcopy(self.observations) if self.copy else self.observations), infos - def step_async(self, actions: np.ndarray): - """Send the calls to :obj:`step` to each sub-environment. + def step( + self, actions: ActType + ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]: + """Take an action for each parallel environment. Args: - actions: Batch of actions. element of :attr:`~VectorEnv.action_space` + actions: A batch of actions, an element of :attr:`action_space`. + + Returns: + Batch of (observations, rewards, terminations, truncations, infos) + """ + self.step_async(actions) + return self.step_wait() + + def step_async(self, actions: np.ndarray): + """Send the calls to :meth:`Env.step` to each sub-environment. + + Args: + actions: Batch of actions.
element of :attr:`VectorEnv.action_space` Raises: ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called). @@ -279,17 +330,17 @@ class AsyncVectorEnv(VectorEnv): if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( f"Calling `step_async` while waiting for a pending call to `{self._state.value}` to complete.", - self._state.value, + str(self._state.value), ) - actions = iterate(self.action_space, actions) - for pipe, action in zip(self.parent_pipes, actions): + iter_actions = iterate(self.action_space, actions) + for pipe, action in zip(self.parent_pipes, iter_actions): pipe.send(("step", action)) self._state = AsyncState.WAITING_STEP def step_wait( - self, timeout: Optional[Union[int, float]] = None - ) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]: + self, timeout: int | float | None = None + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]: """Wait for the calls to :obj:`step` in each sub-environment to finish. Args: @@ -310,44 +361,61 @@ class AsyncVectorEnv(VectorEnv): AsyncState.WAITING_STEP.value, ) - if not self._poll(timeout): + if not self._poll_pipe_envs(timeout): self._state = AsyncState.DEFAULT - raise mp.TimeoutError( + raise multiprocessing.TimeoutError( f"The call to `step_wait` has timed out after {timeout} second(s)." ) - observations_list, rewards, terminateds, truncateds, infos = [], [], [], [], {} + observations, rewards, terminations, truncations, infos = [], [], [], [], {} successes = [] - for i, pipe in enumerate(self.parent_pipes): - result, success = pipe.recv() + for env_idx, pipe in enumerate(self.parent_pipes): + env_step_return, success = pipe.recv() + successes.append(success) if success: - obs, rew, terminated, truncated, info = result - - observations_list.append(obs) - rewards.append(rew) - terminateds.append(terminated) - truncateds.append(truncated) - infos = self._add_info(infos, info, i) + observations.append(env_step_return[0]) + rewards.append(env_step_return[1]) + terminations.append(env_step_return[2]) + truncations.append(env_step_return[3]) + infos = self._add_info(infos, env_step_return[4], env_idx) self._raise_if_errors(successes) - self._state = AsyncState.DEFAULT if not self.shared_memory: self.observations = concatenate( self.single_observation_space, - observations_list, + observations, self.observations, ) + self._state = AsyncState.DEFAULT return ( deepcopy(self.observations) if self.copy else self.observations, - np.array(rewards), - np.array(terminateds, dtype=np.bool_), - np.array(truncateds, dtype=np.bool_), + np.array(rewards, dtype=np.float64), + np.array(terminations, dtype=np.bool_), + np.array(truncations, dtype=np.bool_), infos, ) + def call(self, name: str, *args: Any, **kwargs: Any) -> tuple[Any, ...]: + """Call a method from each parallel environment with args and kwargs. + + Args: + name (str): Name of the method or property to call. + *args: Position arguments to apply to the method call. + **kwargs: Keyword arguments to apply to the method call. + + Returns: + List of the results of the individual calls to the method or property for each environment. + """ + self.call_async(name, *args, **kwargs) + return self.call_wait() + + def render(self) -> tuple[RenderFrame, ...] | None: + """Returns a list of rendered frames from the environments.""" + return self.call("render") + def call_async(self, name: str, *args, **kwargs): """Calls the method with name asynchronously and apply args and kwargs to the method. 
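A short sketch of how the ``call`` and ``render`` helpers added above might be used; the environment ID and the queried attribute are illustrative, not part of this patch:

```python
import gymnasium as gym

if __name__ == "__main__":  # AsyncVectorEnv starts worker processes
    envs = gym.vector.AsyncVectorEnv(
        [lambda: gym.make("CartPole-v1", render_mode="rgb_array") for _ in range(2)]
    )
    envs.reset(seed=42)

    # `call` runs a method (or reads an attribute) in every sub-environment
    # and returns one result per environment.
    specs = envs.call("spec")

    # `render` is implemented as `self.call("render")`, returning the frames.
    frames = envs.render()

    envs.close()
```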
@@ -363,28 +431,27 @@ class AsyncVectorEnv(VectorEnv): self._assert_is_running() if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - "Calling `call_async` while waiting " - f"for a pending call to `{self._state.value}` to complete.", - self._state.value, + f"Calling `call_async` while waiting for a pending call to `{self._state.value}` to complete.", + str(self._state.value), ) for pipe in self.parent_pipes: pipe.send(("_call", (name, args, kwargs))) self._state = AsyncState.WAITING_CALL - def call_wait(self, timeout: Optional[Union[int, float]] = None) -> list: + def call_wait(self, timeout: int | float | None = None) -> tuple[Any, ...]: """Calls all parent pipes and waits for the results. Args: - timeout: Number of seconds before the call to `step_wait` times out. - If `None` (default), the call to `step_wait` never times out. + timeout: Number of seconds before the call to :meth:`step_wait` times out. + If ``None`` (default), the call to :meth:`step_wait` never times out. Returns: List of the results of the individual calls to the method or property for each environment. Raises: - NoAsyncCallError: Calling `call_wait` without any prior call to `call_async`. - TimeoutError: The call to `call_wait` has timed out after timeout second(s). + NoAsyncCallError: Calling :meth:`call_wait` without any prior call to :meth:`call_async`. + TimeoutError: The call to :meth:`call_wait` has timed out after timeout second(s). """ self._assert_is_running() if self._state != AsyncState.WAITING_CALL: @@ -393,9 +460,9 @@ class AsyncVectorEnv(VectorEnv): AsyncState.WAITING_CALL.value, ) - if not self._poll(timeout): + if not self._poll_pipe_envs(timeout): self._state = AsyncState.DEFAULT - raise mp.TimeoutError( + raise multiprocessing.TimeoutError( f"The call to `call_wait` has timed out after {timeout} second(s)." ) @@ -405,7 +472,18 @@ class AsyncVectorEnv(VectorEnv): return results - def set_attr(self, name: str, values: Union[list, tuple, object]): + def get_attr(self, name: str): + """Get a property from each parallel environment. + + Args: + name (str): Name of the property to be get from each individual environment. + + Returns: + The property with name + """ + return self.call(name) + + def set_attr(self, name: str, values: list[Any] | tuple[Any] | object): """Sets an attribute of the sub-environments. Args: @@ -416,23 +494,21 @@ class AsyncVectorEnv(VectorEnv): Raises: ValueError: Values must be a list or tuple with length equal to the number of environments. - AlreadyPendingCallError: Calling `set_attr` while waiting for a pending call to complete. + AlreadyPendingCallError: Calling :meth:`set_attr` while waiting for a pending call to complete. """ self._assert_is_running() if not isinstance(values, (list, tuple)): values = [values for _ in range(self.num_envs)] if len(values) != self.num_envs: raise ValueError( - "Values must be a list or tuple with length equal to the " - f"number of environments. Got `{len(values)}` values for " - f"{self.num_envs} environments." + "Values must be a list or tuple with length equal to the number of environments. " + f"Got `{len(values)}` values for {self.num_envs} environments." 
) if self._state != AsyncState.DEFAULT: raise AlreadyPendingCallError( - "Calling `set_attr` while waiting " - f"for a pending call to `{self._state.value}` to complete.", - self._state.value, + f"Calling `set_attr` while waiting for a pending call to `{self._state.value}` to complete.", + str(self._state.value), ) for pipe, value in zip(self.parent_pipes, values): @@ -440,9 +516,7 @@ class AsyncVectorEnv(VectorEnv): _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) self._raise_if_errors(successes) - def close_extras( - self, timeout: Optional[Union[int, float]] = None, terminate: bool = False - ): + def close_extras(self, timeout: int | float | None = None, terminate: bool = False): """Close the environments & clean up the extra resources (processes and pipes). Args: @@ -462,7 +536,7 @@ class AsyncVectorEnv(VectorEnv): ) function = getattr(self, f"{self._state.value}_wait") function(timeout) - except mp.TimeoutError: + except multiprocessing.TimeoutError: terminate = True if terminate: @@ -483,14 +557,16 @@ class AsyncVectorEnv(VectorEnv): for process in self.processes: process.join() - def _poll(self, timeout=None): + def _poll_pipe_envs(self, timeout: int | None = None): self._assert_is_running() + if timeout is None: return True + end_time = time.perf_counter() + timeout - delta = None for pipe in self.parent_pipes: delta = max(end_time - time.perf_counter(), 0) + if pipe is None: return False if pipe.closed or (not pipe.poll(delta)): @@ -500,22 +576,23 @@ class AsyncVectorEnv(VectorEnv): def _check_spaces(self): self._assert_is_running() spaces = (self.single_observation_space, self.single_action_space) + for pipe in self.parent_pipes: pipe.send(("_check_spaces", spaces)) + results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) self._raise_if_errors(successes) same_observation_spaces, same_action_spaces = zip(*results) + if not all(same_observation_spaces): raise RuntimeError( - "Some environments have an observation space different from " - f"`{self.single_observation_space}`. In order to batch observations, " - "the observation spaces from all environments must be equal." + f"Some environments have an observation space different from `{self.single_observation_space}`. " + "In order to batch observations, the observation spaces from all environments must be equal." ) if not all(same_action_spaces): raise RuntimeError( - "Some environments have an action space different from " - f"`{self.single_action_space}`. In order to batch actions, the " - "action spaces from all environments must be equal." + f"Some environments have an action space different from `{self.single_action_space}`. " + "In order to batch actions, the action spaces from all environments must be equal." ) def _assert_is_running(self): @@ -524,7 +601,7 @@ class AsyncVectorEnv(VectorEnv): f"Trying to operate on `{type(self).__name__}`, after a call to `close()`." 
) - def _raise_if_errors(self, successes): + def _raise_if_errors(self, successes: list[bool]): if all(successes): return @@ -532,10 +609,12 @@ class AsyncVectorEnv(VectorEnv): assert num_errors > 0 for i in range(num_errors): index, exctype, value = self.error_queue.get() + logger.error( f"Received the following error from Worker-{index}: {exctype.__name__}: {value}" ) logger.error(f"Shutting down Worker-{index}.") + self.parent_pipes[index].close() self.parent_pipes[index] = None @@ -549,17 +628,32 @@ class AsyncVectorEnv(VectorEnv): self.close(terminate=True) -def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): - assert shared_memory is None +def _async_worker( + index: int, + env_fn: callable, + pipe: Connection, + parent_pipe: Connection, + shared_memory: bool, + error_queue: Queue, +): env = env_fn() + observation_space = env.observation_space + action_space = env.action_space + parent_pipe.close() + try: while True: command, data = pipe.recv() + if command == "reset": observation, info = env.reset(**data) + if shared_memory: + write_to_shared_memory( + observation_space, index, observation, shared_memory + ) + observation = None pipe.send(((observation, info), True)) - elif command == "step": ( observation, @@ -573,112 +667,43 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): observation, info = env.reset() info["final_observation"] = old_observation info["final_info"] = old_info + + if shared_memory: + write_to_shared_memory( + observation_space, index, observation, shared_memory + ) + observation = None + pipe.send(((observation, reward, terminated, truncated, info), True)) - elif command == "seed": - env.seed(data) - pipe.send((None, True)) elif command == "close": pipe.send((None, True)) break elif command == "_call": name, args, kwargs = data - if name in ["reset", "step", "seed", "close"]: + if name in ["reset", "step", "close", "set_wrapper_attr"]: raise ValueError( - f"Trying to call function `{name}` with " - f"`_call`. Use `{name}` directly instead." + f"Trying to call function `{name}` with `call`, use `{name}` directly instead." ) - function = getattr(env, name) - if callable(function): - pipe.send((function(*args, **kwargs), True)) + + attr = env.get_wrapper_attr(name) + if callable(attr): + pipe.send((attr(*args, **kwargs), True)) else: - pipe.send((function, True)) + pipe.send((attr, True)) elif command == "_setattr": name, value = data - setattr(env, name, value) + env.set_wrapper_attr(name, value) pipe.send((None, True)) elif command == "_check_spaces": pipe.send( ( - (data[0] == env.observation_space, data[1] == env.action_space), + (data[0] == observation_space, data[1] == action_space), True, ) ) else: raise RuntimeError( - f"Received unknown command `{command}`. Must " - "be one of {`reset`, `step`, `seed`, `close`, `_call`, " - "`_setattr`, `_check_spaces`}." 
- ) - except (KeyboardInterrupt, Exception): - error_queue.put((index,) + sys.exc_info()[:2]) - pipe.send((None, False)) - finally: - env.close() - - -def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): - assert shared_memory is not None - env = env_fn() - observation_space = env.observation_space - parent_pipe.close() - try: - while True: - command, data = pipe.recv() - if command == "reset": - observation, info = env.reset(**data) - write_to_shared_memory( - observation_space, index, observation, shared_memory - ) - pipe.send(((None, info), True)) - - elif command == "step": - ( - observation, - reward, - terminated, - truncated, - info, - ) = env.step(data) - if terminated or truncated: - old_observation, old_info = observation, info - observation, info = env.reset() - info["final_observation"] = old_observation - info["final_info"] = old_info - write_to_shared_memory( - observation_space, index, observation, shared_memory - ) - pipe.send(((None, reward, terminated, truncated, info), True)) - elif command == "seed": - env.seed(data) - pipe.send((None, True)) - elif command == "close": - pipe.send((None, True)) - break - elif command == "_call": - name, args, kwargs = data - if name in ["reset", "step", "seed", "close"]: - raise ValueError( - f"Trying to call function `{name}` with " - f"`_call`. Use `{name}` directly instead." - ) - function = getattr(env, name) - if callable(function): - pipe.send((function(*args, **kwargs), True)) - else: - pipe.send((function, True)) - elif command == "_setattr": - name, value = data - setattr(env, name, value) - pipe.send((None, True)) - elif command == "_check_spaces": - pipe.send( - ((data[0] == observation_space, data[1] == env.action_space), True) - ) - else: - raise RuntimeError( - f"Received unknown command `{command}`. Must " - "be one of {`reset`, `step`, `seed`, `close`, `_call`, " - "`_setattr`, `_check_spaces`}." + f"Received unknown command `{command}`. Must be one of [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]." ) except (KeyboardInterrupt, Exception): error_queue.put((index,) + sys.exc_info()[:2]) diff --git a/gymnasium/vector/sync_vector_env.py b/gymnasium/vector/sync_vector_env.py index fdcc6f598..eed598bd4 100644 --- a/gymnasium/vector/sync_vector_env.py +++ b/gymnasium/vector/sync_vector_env.py @@ -1,14 +1,15 @@ -"""A synchronous vector environment.""" +"""Implementation of a synchronous (for loop) vectorization method of any environment.""" +from __future__ import annotations + from copy import deepcopy -from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Iterator, Sequence import numpy as np -from numpy.typing import NDArray from gymnasium import Env -from gymnasium.spaces import Space -from gymnasium.vector.utils import concatenate, create_empty_array, iterate -from gymnasium.vector.vector_env import VectorEnv +from gymnasium.core import ActType, ObsType, RenderFrame +from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate +from gymnasium.vector.vector_env import ArrayType, VectorEnv __all__ = ["SyncVectorEnv"] @@ -19,156 +20,175 @@ class SyncVectorEnv(VectorEnv): Example: >>> import gymnasium as gym - >>> env = gym.vector.SyncVectorEnv([ + >>> envs = gym.make_vec("Pendulum-v1", num_envs=2, vectorization_mode="sync") + >>> envs + SyncVectorEnv(Pendulum-v1, num_envs=2) + >>> envs = gym.vector.SyncVectorEnv([ ... lambda: gym.make("Pendulum-v1", g=9.81), ... 
lambda: gym.make("Pendulum-v1", g=1.62)
        ... ])
-        >>> env.reset(seed=42)
-        (array([[-0.14995256,  0.9886932 , -0.12224312],
-               [ 0.5760367 ,  0.8174238 , -0.91244936]], dtype=float32), {})
+        >>> envs
+        SyncVectorEnv(num_envs=2)
+        >>> obs, infos = envs.reset(seed=42)
+        >>> obs
+        array([[-0.14995256,  0.9886932 , -0.12224312],
+               [ 0.5760367 ,  0.8174238 , -0.91244936]], dtype=float32)
+        >>> infos
+        {}
+        >>> _ = envs.action_space.seed(42)
+        >>> actions = envs.action_space.sample()
+        >>> obs, rewards, terminates, truncates, infos = envs.step(actions)
+        >>> obs
+        array([[-0.1878752 ,  0.98219293,  0.7695615 ],
+               [ 0.6102389 ,  0.79221743, -0.8498053 ]], dtype=float32)
+        >>> rewards
+        array([-2.96562607, -0.99902063])
+        >>> terminates
+        array([False, False])
+        >>> truncates
+        array([False, False])
+        >>> infos
+        {}
+        >>> envs.close()
    """

    def __init__(
        self,
-        env_fns: Iterable[Callable[[], Env]],
-        observation_space: Space = None,
-        action_space: Space = None,
+        env_fns: Iterator[Callable[[], Env]] | Sequence[Callable[[], Env]],
        copy: bool = True,
    ):
        """Vectorized environment that serially runs multiple environments.

        Args:
            env_fns: iterable of callable functions that create the environments.
-            observation_space: Observation space of a single environment. If ``None``,
-                then the observation space of the first environment is taken.
-            action_space: Action space of a single environment. If ``None``,
-                then the action space of the first environment is taken.
            copy: If ``True``, then the :meth:`reset` and :meth:`step` methods return a copy of the observations.

        Raises:
            RuntimeError: If the observation space of some sub-environment does not match observation_space
                (or, by default, the observation space of the first sub-environment).
        """
-        self.env_fns = env_fns
-        self.envs = [env_fn() for env_fn in env_fns]
        self.copy = copy
+        self.env_fns = env_fns
+
+        # Initialise all sub-environments
+        self.envs = [env_fn() for env_fn in env_fns]
+
+        # Define core attributes using the sub-environments
+        # As we support `make_vec(spec)`, we can't include a `spec = self.envs[0].spec` as this doesn't guarantee we can actually recreate the vector env.
+        self.num_envs = len(self.envs)
        self.metadata = self.envs[0].metadata
+        self.render_mode = self.envs[0].render_mode

-        if (observation_space is None) or (action_space is None):
-            observation_space = observation_space or self.envs[0].observation_space
-            action_space = action_space or self.envs[0].action_space
-        super().__init__(
-            num_envs=len(self.envs),
-            observation_space=observation_space,
-            action_space=action_space,
-        )
-
+        # Initialises the single spaces from the sub-environments
+        self.single_observation_space = self.envs[0].observation_space
+        self.single_action_space = self.envs[0].action_space
        self._check_spaces()
-        self.observations = create_empty_array(
+
+        # Initialise the obs and action space based on the single versions and number of sub-environments
+        self.observation_space = batch_space(
+            self.single_observation_space, self.num_envs
+        )
+        self.action_space = batch_space(self.single_action_space, self.num_envs)
+
+        # Initialise attributes used in `step` and `reset`
+        self._observations = create_empty_array(
            self.single_observation_space, n=self.num_envs, fn=np.zeros
        )
        self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
-        self._terminateds = np.zeros((self.num_envs,), dtype=np.bool_)
-        self._truncateds = np.zeros((self.num_envs,), dtype=np.bool_)
-        self._actions = None
+        self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
+        self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)

-    def seed(self, seed: Optional[Union[int, Sequence[int]]] = None):
-        """Sets the seed in all sub-environments.
-
-        Args:
-            seed: The seed
-        """
-        super().seed(seed=seed)
-        if seed is None:
-            seed = [None for _ in range(self.num_envs)]
-        if isinstance(seed, int):
-            seed = [seed + i for i in range(self.num_envs)]
-        assert len(seed) == self.num_envs
-
-        for env, single_seed in zip(self.envs, seed):
-            env.seed(single_seed)
-
-    def reset_wait(
+    def reset(
        self,
-        seed: Optional[Union[int, List[int]]] = None,
-        options: Optional[dict] = None,
-    ):
-        """Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
+        *,
+        seed: int | list[int] | None = None,
+        options: dict[str, Any] | None = None,
+    ) -> tuple[ObsType, dict[str, Any]]:
+        """Resets each of the sub-environments and concatenates the results together.
        Args:
-            seed: The reset environment seed
-            options: Option information for the environment reset
+            seed: Seeds used to reset the sub-environments, either
+                * ``None`` - random seeds for all environments
+                * ``int`` - ``[seed, seed+1, ..., seed+n]``
+                * List of ints - ``[1, 2, 3, ..., n]``
+            options: Option information used for each sub-environment

        Returns:
-            The reset observation of the environment and reset information
+            Concatenated observations and info from each sub-environment
        """
        if seed is None:
            seed = [None for _ in range(self.num_envs)]
-        if isinstance(seed, int):
+        elif isinstance(seed, int):
            seed = [seed + i for i in range(self.num_envs)]
        assert len(seed) == self.num_envs

-        self._terminateds[:] = False
-        self._truncateds[:] = False
-        observations = []
-        infos = {}
+        self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
+        self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)
+
+        observations, infos = [], {}
        for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
-            kwargs = {}
-            if single_seed is not None:
-                kwargs["seed"] = single_seed
-            if options is not None:
-                kwargs["options"] = options
+            env_obs, env_info = env.reset(seed=single_seed, options=options)

-            observation, info = env.reset(**kwargs)
-            observations.append(observation)
-            infos = self._add_info(infos, info, i)
+            observations.append(env_obs)
+            infos = self._add_info(infos, env_info, i)

-        self.observations = concatenate(
-            self.single_observation_space, observations, self.observations
+        # Concatenate the observations
+        self._observations = concatenate(
+            self.single_observation_space, observations, self._observations
        )
-        return (deepcopy(self.observations) if self.copy else self.observations), infos

-    def step_async(self, actions):
-        """Sets :attr:`_actions` for use by the :meth:`step_wait` by converting the ``actions`` to an iterable version."""
-        self._actions = iterate(self.action_space, actions)
+        return deepcopy(self._observations) if self.copy else self._observations, infos

-    def step_wait(self) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]:
+    def step(
+        self, actions: ActType
+    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
        """Steps through each of the environments returning the batched results.
        Returns:
            The batched environment step results
        """
+        actions = iterate(self.action_space, actions)
+
        observations, infos = [], {}
-        for i, (env, action) in enumerate(zip(self.envs, self._actions)):
+        for i, (env, action) in enumerate(zip(self.envs, actions)):
            (
-                observation,
+                env_obs,
                self._rewards[i],
-                self._terminateds[i],
-                self._truncateds[i],
-                info,
+                self._terminations[i],
+                self._truncations[i],
+                env_info,
            ) = env.step(action)

-            if self._terminateds[i] or self._truncateds[i]:
-                old_observation, old_info = observation, info
-                observation, info = env.reset()
-                info["final_observation"] = old_observation
-                info["final_info"] = old_info
-            observations.append(observation)
-            infos = self._add_info(infos, info, i)
-        self.observations = concatenate(
-            self.single_observation_space, observations, self.observations
+            # If the sub-environment terminates or truncates, save the final obs and info in the batched info
+            if self._terminations[i] or self._truncations[i]:
+                old_observation, old_info = env_obs, env_info
+                env_obs, env_info = env.reset()
+
+                env_info["final_observation"] = old_observation
+                env_info["final_info"] = old_info
+
+            observations.append(env_obs)
+            infos = self._add_info(infos, env_info, i)
+
+        # Concatenate the observations
+        self._observations = concatenate(
+            self.single_observation_space, observations, self._observations
        )

        return (
-            deepcopy(self.observations) if self.copy else self.observations,
+            deepcopy(self._observations) if self.copy else self._observations,
            np.copy(self._rewards),
-            np.copy(self._terminateds),
-            np.copy(self._truncateds),
+            np.copy(self._terminations),
+            np.copy(self._truncations),
            infos,
        )

-    def call(self, name, *args, **kwargs) -> tuple:
-        """Calls the method with name and applies args and kwargs.
+    def render(self) -> tuple[RenderFrame, ...] | None:
+        """Returns the rendered frames from the environments."""
+        return tuple(env.render() for env in self.envs)
+
+    def call(self, name: str, *args: Any, **kwargs: Any) -> tuple[Any, ...]:
+        """Calls a sub-environment method with name and applies args and kwargs.

        Args:
            name: The method name
@@ -180,7 +200,8 @@ class SyncVectorEnv(VectorEnv):
        """
        results = []
        for env in self.envs:
-            function = getattr(env, name)
+            function = env.get_wrapper_attr(name)
+
            if callable(function):
                results.append(function(*args, **kwargs))
            else:
@@ -188,7 +209,18 @@

        return tuple(results)

-    def set_attr(self, name: str, values: Union[list, tuple, Any]):
+    def get_attr(self, name: str) -> Any:
+        """Get a property from each parallel environment.
+
+        Args:
+            name (str): Name of the property to get from each individual environment.
+
+        Returns:
+            A tuple of the property values from each sub-environment
+        """
+        return self.call(name)
+
+    def set_attr(self, name: str, values: list[Any] | tuple[Any, ...] | Any):
        """Sets an attribute of the sub-environments.

        Args:
@@ -202,34 +234,33 @@
        """
        if not isinstance(values, (list, tuple)):
            values = [values for _ in range(self.num_envs)]
+
        if len(values) != self.num_envs:
            raise ValueError(
-                "Values must be a list or tuple with length equal to the "
-                f"number of environments. Got `{len(values)}` values for "
-                f"{self.num_envs} environments."
+                "Values must be a list or tuple with length equal to the number of environments. "
+                f"Got `{len(values)}` values for {self.num_envs} environments."
            )

        for env, value in zip(self.envs, values):
-            setattr(env, name, value)
+            env.set_wrapper_attr(name, value)

-    def close_extras(self, **kwargs):
+    def close_extras(self, **kwargs: Any):
        """Close the environments."""
        [env.close() for env in self.envs]

    def _check_spaces(self) -> bool:
+        """Check that each of the environments' observation and action spaces are equivalent to the single observation and action spaces."""
        for env in self.envs:
            if not (env.observation_space == self.single_observation_space):
                raise RuntimeError(
-                    "Some environments have an observation space different from "
-                    f"`{self.single_observation_space}`. In order to batch observations, "
-                    "the observation spaces from all environments must be equal."
+                    f"Some environments have an observation space different from `{self.single_observation_space}`. "
+                    "In order to batch observations, the observation spaces from all environments must be equal."
                )

            if not (env.action_space == self.single_action_space):
                raise RuntimeError(
-                    "Some environments have an action space different from "
-                    f"`{self.single_action_space}`. In order to batch actions, the "
-                    "action spaces from all environments must be equal."
+                    f"Some environments have an action space different from `{self.single_action_space}`. "
+                    "In order to batch actions, the action spaces from all environments must be equal."
                )

        return True
diff --git a/gymnasium/vector/utils/__init__.py b/gymnasium/vector/utils/__init__.py
index 30b320460..a0ad58c3e 100644
--- a/gymnasium/vector/utils/__init__.py
+++ b/gymnasium/vector/utils/__init__.py
@@ -1,26 +1,27 @@
-"""Module for gymnasium vector utils."""
+"""Module for gymnasium vector utility functions."""
+
 from gymnasium.vector.utils.misc import CloudpickleWrapper, clear_mpi_env_vars
-from gymnasium.vector.utils.numpy_utils import concatenate, create_empty_array
 from gymnasium.vector.utils.shared_memory import (
     create_shared_memory,
     read_from_shared_memory,
     write_to_shared_memory,
 )
-from gymnasium.vector.utils.spaces import (
-    _BaseGymSpaces,  # pyright: ignore[reportPrivateUsage]
+from gymnasium.vector.utils.space_utils import (
+    batch_space,
+    concatenate,
+    create_empty_array,
+    iterate,
 )
-from gymnasium.vector.utils.spaces import BaseGymSpaces, batch_space, iterate

 __all__ = [
-    "CloudpickleWrapper",
-    "clear_mpi_env_vars",
+    "batch_space",
+    "iterate",
     "concatenate",
     "create_empty_array",
     "create_shared_memory",
     "read_from_shared_memory",
     "write_to_shared_memory",
-    "BaseGymSpaces",
-    "batch_space",
-    "iterate",
+    "CloudpickleWrapper",
+    "clear_mpi_env_vars",
 ]
diff --git a/gymnasium/vector/utils/misc.py b/gymnasium/vector/utils/misc.py
index c8cd1f368..51f283a4b 100644
--- a/gymnasium/vector/utils/misc.py
+++ b/gymnasium/vector/utils/misc.py
@@ -39,7 +39,7 @@ class CloudpickleWrapper:
 def clear_mpi_env_vars():
     """Clears the MPI of environment variables.

-    `from mpi4py import MPI` will call `MPI_Init` by default.
+    ``from mpi4py import MPI`` will call ``MPI_Init`` by default.
     If the child process has MPI environment variables, MPI will think that the child process
     is an MPI process just like the parent and do bad things such as hang.
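The reorganised `gymnasium.vector.utils` surface above is easiest to see as a round trip. The following is a minimal sketch, not part of the patch: it assumes only `gymnasium` and `numpy` are installed, and the shapes noted in the comments are what `batch_space` and `create_empty_array` produce for this particular space.

```python
import numpy as np

from gymnasium.spaces import Box
from gymnasium.vector.utils import (
    batch_space,
    concatenate,
    create_empty_array,
    iterate,
)

# A single sub-environment space and its batched equivalent for 2 envs
space = Box(low=0.0, high=1.0, shape=(3,), dtype=np.float32, seed=42)
batched = batch_space(space, n=2)  # a Box with shape (2, 3)

# Pre-allocate the output buffer, then stack the per-env samples into it
out = create_empty_array(space, n=2, fn=np.zeros)
samples = [space.sample() for _ in range(2)]
batch = concatenate(space, samples, out)

# `iterate` is the inverse view: it yields the per-env items of a batch
for item in iterate(batched, batch):
    assert item in space
```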
diff --git a/gymnasium/vector/utils/numpy_utils.py b/gymnasium/vector/utils/numpy_utils.py deleted file mode 100644 index f8f862c81..000000000 --- a/gymnasium/vector/utils/numpy_utils.py +++ /dev/null @@ -1,146 +0,0 @@ -"""Numpy utility functions: concatenate space samples and create empty array.""" -from collections import OrderedDict -from functools import singledispatch -from typing import Callable, Iterable, Union - -import numpy as np - -from gymnasium.spaces import ( - Box, - Dict, - Discrete, - MultiBinary, - MultiDiscrete, - Space, - Tuple, -) - - -__all__ = ["concatenate", "create_empty_array"] - - -@singledispatch -def concatenate( - space: Space, items: Iterable, out: Union[tuple, dict, np.ndarray] -) -> Union[tuple, dict, np.ndarray]: - """Concatenate multiple samples from space into a single object. - - Args: - space: Observation space of a single environment in the vectorized environment. - items: Samples to be concatenated. - out: The output object. This object is a (possibly nested) numpy array. - - Returns: - The output object. This object is a (possibly nested) numpy array. - - Raises: - ValueError: Space is not a valid :class:`gym.Space` instance - - Example: - >>> from gymnasium.spaces import Box - >>> import numpy as np - >>> space = Box(low=0, high=1, shape=(3,), seed=42, dtype=np.float32) - >>> out = np.zeros((2, 3), dtype=np.float32) - >>> items = [space.sample() for _ in range(2)] - >>> concatenate(space, items, out) - array([[0.77395606, 0.43887845, 0.85859793], - [0.697368 , 0.09417735, 0.97562236]], dtype=float32) - """ - raise ValueError( - f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance." - ) - - -@concatenate.register(Box) -@concatenate.register(Discrete) -@concatenate.register(MultiDiscrete) -@concatenate.register(MultiBinary) -def _concatenate_base(space, items, out): - return np.stack(items, axis=0, out=out) - - -@concatenate.register(Tuple) -def _concatenate_tuple(space, items, out): - return tuple( - concatenate(subspace, [item[i] for item in items], out[i]) - for (i, subspace) in enumerate(space.spaces) - ) - - -@concatenate.register(Dict) -def _concatenate_dict(space, items, out): - return OrderedDict( - [ - (key, concatenate(subspace, [item[key] for item in items], out[key])) - for (key, subspace) in space.spaces.items() - ] - ) - - -@concatenate.register(Space) -def _concatenate_custom(space, items, out): - return tuple(items) - - -@singledispatch -def create_empty_array( - space: Space, n: int = 1, fn: Callable[..., np.ndarray] = np.zeros -) -> Union[tuple, dict, np.ndarray]: - """Create an empty (possibly nested) numpy array. - - Args: - space: Observation space of a single environment in the vectorized environment. - n: Number of environments in the vectorized environment. If `None`, creates an empty sample from `space`. - fn: Function to apply when creating the empty numpy array. Examples of such functions are `np.empty` or `np.zeros`. - - Returns: - The output object. This object is a (possibly nested) numpy array. - - Raises: - ValueError: Space is not a valid :class:`gym.Space` instance - - Example: - >>> from gymnasium.spaces import Box, Dict - >>> import numpy as np - >>> space = Dict({ - ... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32), - ... 
'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
-        >>> create_empty_array(space, n=2, fn=np.zeros)
-        OrderedDict([('position', array([[0., 0., 0.],
-               [0., 0., 0.]], dtype=float32)), ('velocity', array([[0., 0.],
-               [0., 0.]], dtype=float32))])
-    """
-    raise ValueError(
-        f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
-    )
-
-
-# It is possible for the some of the Box low to be greater than 0, then array is not in space
-@create_empty_array.register(Box)
-# If the Discrete start > 0 or start + length < 0 then array is not in space
-@create_empty_array.register(Discrete)
-@create_empty_array.register(MultiDiscrete)
-@create_empty_array.register(MultiBinary)
-def _create_empty_array_base(space, n=1, fn=np.zeros):
-    shape = space.shape if (n is None) else (n,) + space.shape
-    return fn(shape, dtype=space.dtype)
-
-
-@create_empty_array.register(Tuple)
-def _create_empty_array_tuple(space, n=1, fn=np.zeros):
-    return tuple(create_empty_array(subspace, n=n, fn=fn) for subspace in space.spaces)
-
-
-@create_empty_array.register(Dict)
-def _create_empty_array_dict(space, n=1, fn=np.zeros):
-    return OrderedDict(
-        [
-            (key, create_empty_array(subspace, n=n, fn=fn))
-            for (key, subspace) in space.spaces.items()
-        ]
-    )
-
-
-@create_empty_array.register(Space)
-def _create_empty_array_custom(space, n=1, fn=np.zeros):
-    return None
diff --git a/gymnasium/vector/utils/shared_memory.py b/gymnasium/vector/utils/shared_memory.py
index ea31e0d80..6f4d9c8a1 100644
--- a/gymnasium/vector/utils/shared_memory.py
+++ b/gymnasium/vector/utils/shared_memory.py
@@ -1,9 +1,11 @@
 """Utility functions for vector environments to share memory between processes."""
+from __future__ import annotations
+
 import multiprocessing as mp
 from collections import OrderedDict
 from ctypes import c_bool
 from functools import singledispatch
-from typing import Union
+from typing import Any

 import numpy as np

@@ -12,10 +14,14 @@ from gymnasium.spaces import (
     Box,
     Dict,
     Discrete,
+    Graph,
     MultiBinary,
     MultiDiscrete,
+    Sequence,
     Space,
+    Text,
     Tuple,
+    flatten,
 )

@@ -24,8 +30,8 @@ __all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_m

 @singledispatch
 def create_shared_memory(
-    space: Space, n: int = 1, ctx=mp
-) -> Union[dict, tuple, mp.Array]:
+    space: Space[Any], n: int = 1, ctx=mp
+) -> dict[str, Any] | tuple[Any, ...] | mp.Array:
     """Create a shared memory object, to be shared across processes.

     This eventually contains the observations from the vectorized environment.
@@ -41,20 +47,24 @@ def create_shared_memory(
     Raises:
         CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
     """
-    raise CustomSpaceError(
-        "Cannot create a shared memory for space with "
-        f"type `{type(space)}`. Shared memory only supports "
-        "default Gymnasium spaces (e.g. `Box`, `Tuple`, "
-        "`Dict`, etc...), and does not support custom "
-        "Gymnasium spaces."
-    )
+    if isinstance(space, Space):
+        raise CustomSpaceError(
+            f"Space of type `{type(space)}` doesn't have a registered `create_shared_memory` function. Register `{type(space)}` for `create_shared_memory` to support it."
+        )
+    else:
+        raise TypeError(
+            f"The space provided to `create_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
+        )


 @create_shared_memory.register(Box)
 @create_shared_memory.register(Discrete)
 @create_shared_memory.register(MultiDiscrete)
 @create_shared_memory.register(MultiBinary)
-def _create_base_shared_memory(space, n: int = 1, ctx=mp):
+def _create_base_shared_memory(
+    space: Box | Discrete | MultiDiscrete | MultiBinary, n: int = 1, ctx=mp
+):
+    assert space.dtype is not None
     dtype = space.dtype.char
     if dtype in "?":
         dtype = c_bool
@@ -62,14 +72,14 @@


 @create_shared_memory.register(Tuple)
-def _create_tuple_shared_memory(space, n: int = 1, ctx=mp):
+def _create_tuple_shared_memory(space: Tuple, n: int = 1, ctx=mp):
     return tuple(
         create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
     )


 @create_shared_memory.register(Dict)
-def _create_dict_shared_memory(space, n=1, ctx=mp):
+def _create_dict_shared_memory(space: Dict, n: int = 1, ctx=mp):
     return OrderedDict(
         [
             (key, create_shared_memory(subspace, n=n, ctx=ctx))
@@ -78,10 +88,23 @@
     )


+@create_shared_memory.register(Text)
+def _create_text_shared_memory(space: Text, n: int = 1, ctx=mp):
+    return ctx.Array(np.dtype(np.int32).char, n * space.max_length)
+
+
+@create_shared_memory.register(Graph)
+@create_shared_memory.register(Sequence)
+def _create_dynamic_shared_memory(space: Graph | Sequence, n: int = 1, ctx=mp):
+    raise TypeError(
+        f"As {space} has a dynamic shape, it is not possible to create a static shared memory."
+    )
+
+
 @singledispatch
 def read_from_shared_memory(
-    space: Space, shared_memory: Union[dict, tuple, mp.Array], n: int = 1
-) -> Union[dict, tuple, np.ndarray]:
+    space: Space, shared_memory: dict | tuple | mp.Array, n: int = 1
+) -> dict[str, Any] | tuple[Any, ...] | np.ndarray:
     """Read the batch of observations from shared memory as a numpy array.

     ..notes::
@@ -101,27 +124,30 @@
     Raises:
         CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
     """
-    raise CustomSpaceError(
-        "Cannot read from a shared memory for space with "
-        f"type `{type(space)}`. Shared memory only supports "
-        "default Gymnasium spaces (e.g. `Box`, `Tuple`, "
-        "`Dict`, etc...), and does not support custom "
-        "Gymnasium spaces."
-    )
+    if isinstance(space, Space):
+        raise CustomSpaceError(
+            f"Space of type `{type(space)}` doesn't have a registered `read_from_shared_memory` function. Register `{type(space)}` for `read_from_shared_memory` to support it."
+        )
+    else:
+        raise TypeError(
+            f"The space provided to `read_from_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}"
+        )


 @read_from_shared_memory.register(Box)
 @read_from_shared_memory.register(Discrete)
 @read_from_shared_memory.register(MultiDiscrete)
 @read_from_shared_memory.register(MultiBinary)
-def _read_base_from_shared_memory(space, shared_memory, n: int = 1):
+def _read_base_from_shared_memory(
+    space: Box | Discrete | MultiDiscrete | MultiBinary, shared_memory, n: int = 1
+):
     return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
         (n,) + space.shape
     )


 @read_from_shared_memory.register(Tuple)
-def _read_tuple_from_shared_memory(space, shared_memory, n: int = 1):
+def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
     return tuple(
         read_from_shared_memory(subspace, memory, n=n)
         for (memory, subspace) in zip(shared_memory, space.spaces)
@@ -129,7 +155,7 @@


 @read_from_shared_memory.register(Dict)
-def _read_dict_from_shared_memory(space, shared_memory, n: int = 1):
+def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
     return OrderedDict(
         [
             (key, read_from_shared_memory(subspace, shared_memory[key], n=n))
@@ -138,12 +164,30 @@
     )


+@read_from_shared_memory.register(Text)
+def _read_text_from_shared_memory(space: Text, shared_memory, n: int = 1) -> tuple[str]:
+    data = np.frombuffer(shared_memory.get_obj(), dtype=np.int32).reshape(
+        (n, space.max_length)
+    )
+
+    return tuple(
+        "".join(
+            [
+                space.character_list[val]
+                for val in values
+                if val < len(space.character_set)
+            ]
+        )
+        for values in data
+    )
+
+
 @singledispatch
 def write_to_shared_memory(
     space: Space,
     index: int,
     value: np.ndarray,
-    shared_memory: Union[dict, tuple, mp.Array],
+    shared_memory: dict[str, Any] | tuple[Any, ...] | mp.Array,
 ):
     """Write the observation of a single environment into shared memory.

     Args:
@@ -157,20 +201,26 @@
     Raises:
         CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
     """
-    raise CustomSpaceError(
-        "Cannot write to a shared memory for space with "
-        f"type `{type(space)}`. Shared memory only supports "
-        "default Gymnasium spaces (e.g. `Box`, `Tuple`, "
-        "`Dict`, etc...), and does not support custom "
-        "Gymnasium spaces."
-    )
+    if isinstance(space, Space):
+        raise CustomSpaceError(
+            f"Space of type `{type(space)}` doesn't have a registered `write_to_shared_memory` function. Register `{type(space)}` for `write_to_shared_memory` to support it."
+ ) + else: + raise TypeError( + f"The space provided to `write_to_shared_memory` is not a gymnasium Space instance, type: {type(space)}, {space}" + ) @write_to_shared_memory.register(Box) @write_to_shared_memory.register(Discrete) @write_to_shared_memory.register(MultiDiscrete) @write_to_shared_memory.register(MultiBinary) -def _write_base_to_shared_memory(space, index, value, shared_memory): +def _write_base_to_shared_memory( + space: Box | Discrete | MultiDiscrete | MultiBinary, + index: int, + value, + shared_memory, +): size = int(np.prod(space.shape)) destination = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype) np.copyto( @@ -180,12 +230,26 @@ def _write_base_to_shared_memory(space, index, value, shared_memory): @write_to_shared_memory.register(Tuple) -def _write_tuple_to_shared_memory(space, index, values, shared_memory): +def _write_tuple_to_shared_memory( + space: Tuple, index: int, values: tuple[Any, ...], shared_memory +): for value, memory, subspace in zip(values, shared_memory, space.spaces): write_to_shared_memory(subspace, index, value, memory) @write_to_shared_memory.register(Dict) -def _write_dict_to_shared_memory(space, index, values, shared_memory): +def _write_dict_to_shared_memory( + space: Dict, index: int, values: dict[str, Any], shared_memory +): for key, subspace in space.spaces.items(): write_to_shared_memory(subspace, index, values[key], shared_memory[key]) + + +@write_to_shared_memory.register(Text) +def _write_text_to_shared_memory(space: Text, index: int, values: str, shared_memory): + size = space.max_length + destination = np.frombuffer(shared_memory.get_obj(), dtype=np.int32) + np.copyto( + destination[index * size : (index + 1) * size], + flatten(space, values), + ) diff --git a/gymnasium/experimental/vector/utils/space_utils.py b/gymnasium/vector/utils/space_utils.py similarity index 98% rename from gymnasium/experimental/vector/utils/space_utils.py rename to gymnasium/vector/utils/space_utils.py index a1b3ddff7..3d103ecdf 100644 --- a/gymnasium/experimental/vector/utils/space_utils.py +++ b/gymnasium/vector/utils/space_utils.py @@ -150,7 +150,7 @@ def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator: The output object. This object is a (possibly nested) numpy array. Raises: - ValueError: Space is not an instance of :class:`gym.Space` + ValueError: Space is not an instance of :class:`gymnasium.Space` Example: >>> from gymnasium.spaces import Box, Dict @@ -311,14 +311,14 @@ def create_empty_array( Args: space: Observation space of a single environment in the vectorized environment. - n: Number of environments in the vectorized environment. If `None`, creates an empty sample from `space`. - fn: Function to apply when creating the empty numpy array. Examples of such functions are `np.empty` or `np.zeros`. + n: Number of environments in the vectorized environment. If ``None``, creates an empty sample from ``space``. + fn: Function to apply when creating the empty numpy array. Examples of such functions are ``np.empty`` or ``np.zeros``. Returns: The output object. This object is a (possibly nested) numpy array. 
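For the shared-memory utilities rewritten above (`create_shared_memory`, `write_to_shared_memory`, `read_from_shared_memory`), the following is a minimal single-process sketch of the round trip that `AsyncVectorEnv` performs across worker processes. It is illustrative only, not part of the patch, and assumes a flat `Box` space:

```python
import multiprocessing as mp

import numpy as np

from gymnasium.spaces import Box
from gymnasium.vector.utils import (
    create_shared_memory,
    read_from_shared_memory,
    write_to_shared_memory,
)

n_envs = 2
space = Box(low=0.0, high=1.0, shape=(3,), dtype=np.float32, seed=0)

# One flat mp.Array sized to hold an observation for every sub-environment
shared = create_shared_memory(space, n=n_envs, ctx=mp)

# Each worker writes its observation into its own slice of the buffer ...
for index in range(n_envs):
    write_to_shared_memory(space, index, space.sample(), shared)

# ... and the main process reads them back as a single (n, 3) batch view
batch = read_from_shared_memory(space, shared, n=n_envs)
assert batch.shape == (n_envs, 3)
```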
Raises: - ValueError: Space is not a valid :class:`gym.Space` instance + ValueError: Space is not a valid :class:`gymnasium.Space` instance Example: >>> from gymnasium.spaces import Box, Dict diff --git a/gymnasium/vector/utils/spaces.py b/gymnasium/vector/utils/spaces.py deleted file mode 100644 index a05209de5..000000000 --- a/gymnasium/vector/utils/spaces.py +++ /dev/null @@ -1,215 +0,0 @@ -"""Utility functions for gymnasium spaces: batch space and iterator.""" -from collections import OrderedDict -from copy import deepcopy -from functools import singledispatch -from typing import Iterator - -import numpy as np - -from gymnasium.error import CustomSpaceError -from gymnasium.spaces import ( - Box, - Dict, - Discrete, - MultiBinary, - MultiDiscrete, - Space, - Tuple, -) - - -BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary) -_BaseGymSpaces = BaseGymSpaces -__all__ = ["BaseGymSpaces", "_BaseGymSpaces", "batch_space", "iterate"] - - -@singledispatch -def batch_space(space: Space, n: int = 1) -> Space: - """Create a (batched) space, containing multiple copies of a single space. - - Args: - space: Space (e.g. the observation space) for a single environment in the vectorized environment. - n: Number of environments in the vectorized environment. - - Returns: - Space (e.g. the observation space) for a batch of environments in the vectorized environment. - - Raises: - ValueError: Cannot batch space that is not a valid :class:`gym.Space` instance - - Example: - >>> from gymnasium.spaces import Box, Dict - >>> import numpy as np - >>> space = Dict({ - ... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32), - ... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32) - ... }) - >>> batch_space(space, n=5) - Dict('position': Box(0.0, 1.0, (5, 3), float32), 'velocity': Box(0.0, 1.0, (5, 2), float32)) - """ - raise ValueError( - f"Cannot batch space with type `{type(space)}`. The space must be a valid `gymnasium.Space` instance." 
- ) - - -@batch_space.register(Box) -def _batch_space_box(space, n=1): - repeats = tuple([n] + [1] * space.low.ndim) - low, high = np.tile(space.low, repeats), np.tile(space.high, repeats) - return Box(low=low, high=high, dtype=space.dtype, seed=deepcopy(space.np_random)) - - -@batch_space.register(Discrete) -def _batch_space_discrete(space, n=1): - return MultiDiscrete( - np.full((n,), space.n, dtype=space.dtype), - dtype=space.dtype, - seed=deepcopy(space.np_random), - start=np.full((n,), space.start, dtype=space.dtype), - ) - - -@batch_space.register(MultiDiscrete) -def _batch_space_multidiscrete(space, n=1): - repeats = tuple([n] + [1] * space.nvec.ndim) - low = np.tile(space.start, repeats) - high = low + np.tile(space.nvec, repeats) - 1 - return Box( - low=low, - high=high, - dtype=space.dtype, - seed=deepcopy(space.np_random), - ) - - -@batch_space.register(MultiBinary) -def _batch_space_multibinary(space, n=1): - return Box( - low=0, - high=1, - shape=(n,) + space.shape, - dtype=space.dtype, - seed=deepcopy(space.np_random), - ) - - -@batch_space.register(Tuple) -def _batch_space_tuple(space, n=1): - return Tuple( - tuple(batch_space(subspace, n=n) for subspace in space.spaces), - seed=deepcopy(space.np_random), - ) - - -@batch_space.register(Dict) -def _batch_space_dict(space, n=1): - return Dict( - OrderedDict( - [ - (key, batch_space(subspace, n=n)) - for (key, subspace) in space.spaces.items() - ] - ), - seed=deepcopy(space.np_random), - ) - - -@batch_space.register(Space) -def _batch_space_custom(space, n=1): - # Without deepcopy, then the space.np_random is batched_space.spaces[0].np_random - # Which is an issue if you are sampling actions of both the original space and the batched space - batched_space = Tuple( - tuple(deepcopy(space) for _ in range(n)), seed=deepcopy(space.np_random) - ) - new_seeds = list(map(int, batched_space.np_random.integers(0, 1e8, n))) - batched_space.seed(new_seeds) - return batched_space - - -@singledispatch -def iterate(space: Space, items) -> Iterator: - """Iterate over the elements of a (batched) space. - - Args: - space: Space to which `items` belong to. - items: Items to be iterated over. - - Returns: - Iterator over the elements in `items`. - - Raises: - ValueError: Space is not an instance of :class:`gym.Space` - - Example: - >>> from gymnasium.spaces import Box, Dict - >>> import numpy as np - >>> space = Dict({ - ... 'position': Box(low=0, high=1, shape=(2, 3), seed=42, dtype=np.float32), - ... 'velocity': Box(low=0, high=1, shape=(2, 2), seed=42, dtype=np.float32)}) - >>> items = space.sample() - >>> it = iterate(space, items) - >>> next(it) - OrderedDict([('position', array([0.77395606, 0.43887845, 0.85859793], dtype=float32)), ('velocity', array([0.77395606, 0.43887845], dtype=float32))]) - >>> next(it) - OrderedDict([('position', array([0.697368 , 0.09417735, 0.97562236], dtype=float32)), ('velocity', array([0.85859793, 0.697368 ], dtype=float32))]) - >>> next(it) - Traceback (most recent call last): - ... - StopIteration - """ - raise ValueError( - f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance." 
-    )
-
-
-@iterate.register(Discrete)
-def _iterate_discrete(space, items):
-    raise TypeError("Unable to iterate over a space of type `Discrete`.")
-
-
-@iterate.register(Box)
-@iterate.register(MultiDiscrete)
-@iterate.register(MultiBinary)
-def _iterate_base(space, items):
-    try:
-        return iter(items)
-    except TypeError as e:
-        raise TypeError(
-            f"Unable to iterate over the following elements: {items}"
-        ) from e
-
-
-@iterate.register(Tuple)
-def _iterate_tuple(space, items):
-    # If this is a tuple of custom subspaces only, then simply iterate over items
-    if all(
-        isinstance(subspace, Space)
-        and (not isinstance(subspace, BaseGymSpaces + (Tuple, Dict)))
-        for subspace in space.spaces
-    ):
-        return iter(items)
-
-    return zip(
-        *[iterate(subspace, items[i]) for i, subspace in enumerate(space.spaces)]
-    )
-
-
-@iterate.register(Dict)
-def _iterate_dict(space, items):
-    keys, values = zip(
-        *[
-            (key, iterate(subspace, items[key]))
-            for key, subspace in space.spaces.items()
-        ]
-    )
-    for item in zip(*values):
-        yield OrderedDict([(key, value) for (key, value) in zip(keys, item)])
-
-
-@iterate.register(Space)
-def _iterate_custom(space, items):
-    raise CustomSpaceError(
-        f"Unable to iterate over {items}, since {space} "
-        "is a custom `gymnasium.Space` instance (i.e. not one of "
-        "`Box`, `Dict`, etc...)."
-    )
diff --git a/gymnasium/vector/vector_env.py b/gymnasium/vector/vector_env.py
index 70b98858b..6150977d1 100644
--- a/gymnasium/vector/vector_env.py
+++ b/gymnasium/vector/vector_env.py
@@ -1,30 +1,46 @@
 """Base class for vectorized environments."""
-from typing import Any, List, Optional, Tuple, Union
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Generic, TypeVar

 import numpy as np
-from numpy.typing import NDArray

 import gymnasium as gym
-from gymnasium.vector.utils.spaces import batch_space
+from gymnasium.core import ActType, ObsType, RenderFrame
+from gymnasium.utils import seeding

-__all__ = ["VectorEnv"]
+if TYPE_CHECKING:
+    from gymnasium.envs.registration import EnvSpec
+
+ArrayType = TypeVar("ArrayType")

-class VectorEnv(gym.Env):
+__all__ = [
+    "VectorEnv",
+    "VectorWrapper",
+    "VectorObservationWrapper",
+    "VectorActionWrapper",
+    "VectorRewardWrapper",
+    "ArrayType",
+]
+
+
+class VectorEnv(Generic[ObsType, ActType, ArrayType]):
     """Base class for vectorized environments to run multiple independent copies of the same environment in parallel.

     Vector environments can provide a linear speed-up in the steps taken per second through sampling multiple
     sub-environments at the same time. To prevent terminated environments waiting until all sub-environments have
-    terminated or truncated, the vector environments autoreset sub-environments after they terminate or truncated.
-    As a result, the final step's observation and info are overwritten by the reset's observation and info.
-    Therefore, the observation and info for the final step of a sub-environment is stored in the info parameter,
+    terminated or truncated, the vector environments automatically reset sub-environments after they terminate or truncate (within the same step call).
+    As a result, the step's observation and info are overwritten by the reset's observation and info.
+    To preserve this data, the observation and info for the final step of a sub-environment are stored in the info parameter,
     using `"final_observation"` and `"final_info"` respectively. See :meth:`step` for more information.
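As a hedged illustration of the autoreset behaviour described in the docstring above (a sketch, not part of the patch), the loop below shows how the final observation of a finished sub-environment can be recovered from the batched info. It assumes `CartPole-v1` is registered and relies on the `final_observation`/`_final_observation` keys produced by `_add_info` later in this file:

```python
import gymnasium as gym

envs = gym.make_vec("CartPole-v1", num_envs=2, vectorization_mode="sync")
obs, infos = envs.reset(seed=0)

for _ in range(500):
    obs, rewards, terminations, truncations, infos = envs.step(
        envs.action_space.sample()
    )
    if "final_observation" in infos:
        # `_final_observation` masks which sub-environments just finished;
        # `obs` already holds their post-reset observations.
        finished = infos["_final_observation"]
        terminal_obs = infos["final_observation"][finished]
        break

envs.close()
```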
-    The vector environments batch `observations`, `rewards`, `terminations`, `truncations` and `info` for each
-    parallel environment. In addition, :meth:`step` expects to receive a batch of actions for each parallel environment.
+    The vector environments batch `observations`, `rewards`, `terminations`, `truncations` and `info` for each
+    sub-environment. In addition, :meth:`step` expects to receive a batch of actions for each parallel environment.

-    Gymnasium contains two types of Vector environments: :class:`AsyncVectorEnv` and :class:`SyncVectorEnv`.
+    Gymnasium contains two generalised Vector environments: :class:`AsyncVectorEnv` and :class:`SyncVectorEnv` along with
+    several custom vector environment implementations.

     The Vector Environments have the additional attributes for users to understand the implementation
@@ -34,89 +50,67 @@ class VectorEnv(gym.Env):
     - :attr:`action_space` - The batched action space of the vector environment
     - :attr:`single_action_space` - The action space of a single sub-environment

-    Note:
-        The info parameter of :meth:`reset` and :meth:`step` was originally implemented before OpenAI Gym v25 was a list
-        of dictionary for each sub-environment. However, this was modified in OpenAI Gym v25+ and in Gymnasium to a
-        dictionary with a NumPy array for each key. To use the old info style using the :class:`VectorListInfo`.
+    Examples:
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync", wrappers=(gym.wrappers.TimeAwareObservation,))
+        >>> envs = gym.wrappers.vector.ClipReward(envs, min_reward=0.2, max_reward=0.8)
+        >>> envs
+        <ClipReward, SyncVectorEnv(CartPole-v1, num_envs=3)>
+        >>> observations, infos = envs.reset(seed=123)
+        >>> observations
+        array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282,  0.        ],
+               [ 0.02852531,  0.02858594,  0.0469136 ,  0.02480598,  0.        ],
+               [ 0.03517495, -0.000635  , -0.01098382, -0.03203924,  0.        ]])
+        >>> infos
+        {}
+        >>> _ = envs.action_space.seed(123)
+        >>> observations, rewards, terminations, truncations, infos = envs.step(envs.action_space.sample())
+        >>> observations
+        array([[ 0.01734283,  0.15089367, -0.02859527, -0.33293587,  1.        ],
+               [ 0.02909703, -0.16717631,  0.04740972,  0.3319138 ,  1.        ],
+               [ 0.03516225, -0.19559774, -0.01162461,  0.25715804,  1.        ]])
+        >>> rewards
+        array([0.8, 0.8, 0.8])
+        >>> terminations
+        array([False, False, False])
+        >>> truncations
+        array([False, False, False])
+        >>> infos
+        {}
+        >>> envs.close()

     Note:
-        To render the sub-environments, use :meth:`call` with "render" arguments. Remember to set the `render_modes`
-        for all the sub-environments during initialization.
+        The info parameter of :meth:`reset` and :meth:`step` was originally implemented before v0.25 as a list
+        of dictionaries for each sub-environment. However, this was modified in v0.25+ to be a
+        dictionary with a NumPy array for each key. To use the old info style, utilise the :class:`DictInfoToList` wrapper.

     Note:
         All parallel environments should share the identical observation and action spaces.
         In other words, a vector of multiple different environments is not supported.
+
+    Note:
+        :func:`make_vec` is the equivalent function to :func:`make` for vector environments.
     """

-    def __init__(
-        self,
-        num_envs: int,
-        observation_space: gym.Space,
-        action_space: gym.Space,
-    ):
-        """Base class for vectorized environments.
-
-        Args:
-            num_envs: Number of environments in the vectorized environment.
-            observation_space: Observation space of a single environment.
- action_space: Action space of a single environment. - """ - self.num_envs = num_envs - self.is_vector_env = True - self.observation_space = batch_space(observation_space, n=num_envs) - self.action_space = batch_space(action_space, n=num_envs) + observation_space: gym.Space + action_space: gym.Space + single_observation_space: gym.Space + single_action_space: gym.Space - self.closed = False - self.viewer = None + num_envs: int - # The observation and action spaces of a single environment are - # kept in separate properties - self.single_observation_space = observation_space - self.single_action_space = action_space - - def reset_async( - self, - seed: Optional[Union[int, List[int]]] = None, - options: Optional[dict] = None, - ): - """Reset the sub-environments asynchronously. - - This method will return ``None``. A call to :meth:`reset_async` should be followed - by a call to :meth:`reset_wait` to retrieve the results. - - Args: - seed: The reset seed - options: Reset options - """ - pass - - def reset_wait( - self, - seed: Optional[Union[int, List[int]]] = None, - options: Optional[dict] = None, - ): - """Retrieves the results of a :meth:`reset_async` call. - - A call to this method must always be preceded by a call to :meth:`reset_async`. - - Args: - seed: The reset seed - options: Reset options - - Returns: - The results from :meth:`reset_async` - - Raises: - NotImplementedError: VectorEnv does not implement function - """ - raise NotImplementedError("VectorEnv does not implement function") + _np_random: np.random.Generator | None = None def reset( self, *, - seed: Optional[Union[int, List[int]]] = None, - options: Optional[dict] = None, - ): + seed: int | list[int] | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[ObsType, dict[str, Any]]: # type: ignore """Reset all parallel environments and return a batch of initial observations and info. Args: @@ -128,47 +122,26 @@ class VectorEnv(gym.Env): Example: >>> import gymnasium as gym - >>> envs = gym.vector.make("CartPole-v1", num_envs=3) - >>> envs.reset(seed=42) - (array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], + >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync") + >>> observations, infos = envs.reset(seed=42) + >>> observations + array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], [ 0.01522993, -0.04562247, -0.04799704, 0.03392126], [-0.03774345, -0.02418869, -0.00942293, 0.0469184 ]], - dtype=float32), {}) + dtype=float32) + >>> infos + {} """ - self.reset_async(seed=seed, options=options) - return self.reset_wait(seed=seed, options=options) - - def step_async(self, actions): - """Asynchronously performs steps in the sub-environments. - - The results can be retrieved via a call to :meth:`step_wait`. - - Args: - actions: The actions to take asynchronously - """ - - def step_wait( - self, **kwargs - ) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]: - """Retrieves the results of a :meth:`step_async` call. - - A call to this method must always be preceded by a call to :meth:`step_async`. - - Args: - **kwargs: Additional keywords for vector implementation - - Returns: - The results from the :meth:`step_async` call - """ - raise NotImplementedError() + if seed is not None: + self._np_random, seed = seeding.np_random(seed) def step( - self, actions - ) -> Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], dict]: + self, actions: ActType + ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]: """Take an action for each parallel environment. 
        Args:
-            actions: element of :attr:`action_space` Batch of actions.
+            actions: Batch of actions with the :attr:`action_space` shape.

        Returns:
            Batch of (observations, rewards, terminations, truncations, infos)
@@ -181,10 +154,10 @@
        Example:
            >>> import gymnasium as gym
            >>> import numpy as np
-            >>> envs = gym.vector.make("CartPole-v1", num_envs=3)
+            >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
            >>> _ = envs.reset(seed=42)
-            >>> actions = np.array([1, 0, 1])
-            >>> observations, rewards, termination, truncation, infos = envs.step(actions)
+            >>> actions = np.array([1, 0, 1], dtype=np.int32)
+            >>> observations, rewards, terminations, truncations, infos = envs.step(actions)
            >>> observations
            array([[ 0.02727336,  0.18847767,  0.03625453, -0.26141977],
                   [ 0.01431748, -0.24002443, -0.04731862,  0.3110827 ],
                  dtype=float32)
            >>> rewards
            array([1., 1., 1.])
-            >>> termination
+            >>> terminations
            array([False, False, False])
-            >>> truncation
+            >>> truncations
            array([False, False, False])
            >>> infos
            {}
        """

-        self.step_async(actions)
-        return self.step_wait()

-    def call_async(self, name, *args, **kwargs):
-        """Calls a method name for each parallel environment asynchronously."""
-
-    def call_wait(self, **kwargs) -> List[Any]:  # type: ignore
-        """After calling a method in :meth:`call_async`, this function collects the results."""
-
-    def call(self, name: str, *args, **kwargs) -> List[Any]:
-        """Call a method, or get a property, from each parallel environment.
-
-        Args:
-            name (str): Name of the method or property to call.
-            *args: Arguments to apply to the method call.
-            **kwargs: Keyword arguments to apply to the method call.
+    def render(self) -> tuple[RenderFrame, ...] | None:
+        """Returns the rendered frames from the parallel environments.

        Returns:
-            List of the results of the individual calls to the method or property for each environment.
+            A tuple of rendered frames from the parallel environments
        """
-        self.call_async(name, *args, **kwargs)
-        return self.call_wait()
+        raise NotImplementedError(
+            f"{self.__str__()} render function is not implemented."
+        )

-    def get_attr(self, name: str):
-        """Get a property from each parallel environment.
-
-        Args:
-            name (str): Name of the property to be get from each individual environment.
-
-        Returns:
-            The property with name
-        """
-        return self.call(name)
-
-    def set_attr(self, name: str, values: Union[list, tuple, object]):
-        """Set a property in each sub-environment.
-
-        Args:
-            name (str): Name of the property to be set in each individual environment.
-            values (list, tuple, or object): Values of the property to be set to. If `values` is a list or
-                tuple, then it corresponds to the values for each individual environment, otherwise a single value
-                is set for all environments.
-        """
-
-    def close_extras(self, **kwargs):
-        """Clean up the extra resources e.g. beyond what's in this base class."""
-        pass
-
-    def close(self, **kwargs):
+    def close(self, **kwargs: Any):
        """Close all parallel environments and release resources.

        It also closes all the existing image viewers, then calls :meth:`close_extras` and set
@@ -266,12 +202,37 @@
        """
        if self.closed:
            return
-        if self.viewer is not None:
-            self.viewer.close()
+
        self.close_extras(**kwargs)
        self.closed = True

-    def _add_info(self, infos: dict, info: dict, env_num: int) -> dict:
+    def close_extras(self, **kwargs: Any):
+        """Clean up the extra resources, e.g. beyond what's in this base class."""
+        pass
+
+    @property
+    def np_random(self) -> np.random.Generator:
+        """Returns the environment's internal :attr:`_np_random`; if not set, it will be initialised with a random seed.
+
+        Returns:
+            An instance of `np.random.Generator`
+        """
+        if self._np_random is None:
+            self._np_random, seed = seeding.np_random()
+        return self._np_random
+
+    @np_random.setter
+    def np_random(self, value: np.random.Generator):
+        self._np_random = value
+
+    @property
+    def unwrapped(self):
+        """Return the base environment."""
+        return self
+
+    def _add_info(
+        self, infos: dict[str, Any], info: dict[str, Any], env_num: int
+    ) -> dict[str, Any]:
        """Add env info to the info dictionary of the vectorized environment.

        Given the `info` of a single environment add it to the `infos` dictionary
@@ -298,7 +259,7 @@
            infos[k], infos[f"_{k}"] = info_array, array_mask
        return infos

-    def _init_info_arrays(self, dtype: type) -> Tuple[np.ndarray, np.ndarray]:
+    def _init_info_arrays(self, dtype: type) -> tuple[np.ndarray, np.ndarray]:
        """Initialize the info array.

        Initialize the info array. If the dtype is numeric
@@ -335,12 +296,14 @@
            A string containing the class name, number of environments and environment spec id
        """
        if self.spec is None:
-            return f"{self.__class__.__name__}({self.num_envs})"
+            return f"{self.__class__.__name__}(num_envs={self.num_envs})"
        else:
-            return f"{self.__class__.__name__}({self.spec.id}, {self.num_envs})"
+            return (
+                f"{self.__class__.__name__}({self.spec.id}, num_envs={self.num_envs})"
+            )


-class VectorEnvWrapper(VectorEnv):
+class VectorWrapper(VectorEnv):
    """Wraps the vectorized environment to allow a modular transformation.

    This class is the base class for all wrappers for vectorized environments. The subclass
    could override some methods to change the behavior of the original vectorized environment
    without touching the original code.

    Note:
        Don't forget to call :meth:`super().__init__(env)` if the subclass overrides :meth:`__init__`.
    """

    def __init__(self, env: VectorEnv):
-        assert isinstance(env, VectorEnv)
+        """Initialize the vectorized environment wrapper.
+
+        Args:
+            env: The environment to wrap
+        """
        self.env = env
+        assert isinstance(env, VectorEnv)

-    # explicitly forward the methods defined in VectorEnv
-    # to self.env (instead of the base class)
-    def reset_async(self, **kwargs):
-        return self.env.reset_async(**kwargs)
+        self._observation_space: gym.Space | None = None
+        self._action_space: gym.Space | None = None
+        self._single_observation_space: gym.Space | None = None
+        self._single_action_space: gym.Space | None = None

-    def reset_wait(self, **kwargs):
-        return self.env.reset_wait(**kwargs)
+    def reset(
+        self,
+        *,
+        seed: int | list[int] | None = None,
+        options: dict[str, Any] | None = None,
+    ) -> tuple[ObsType, dict[str, Any]]:
+        """Reset all environments using seed and options."""
+        return self.env.reset(seed=seed, options=options)

-    def step_async(self, actions):
-        return self.env.step_async(actions)
+    def step(
+        self, actions: ActType
+    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
+        """Step through all environments using the actions, returning the batched data."""
+        return self.env.step(actions)

-    def step_wait(self):
-        return self.env.step_wait()
+    def render(self) -> tuple[RenderFrame, ...] | None:
+        """Returns the rendered frames from the base vector environment."""
+        return self.env.render()

-    def close(self, **kwargs):
+    def close(self, **kwargs: Any):
+        """Close all environments."""
        return self.env.close(**kwargs)

-    def close_extras(self, **kwargs):
+    def close_extras(self, **kwargs: Any):
+        """Close all extra resources."""
        return self.env.close_extras(**kwargs)

-    def call(self, name, *args, **kwargs):
-        return self.env.call(name, *args, **kwargs)
-
-    def set_attr(self, name, values):
-        return self.env.set_attr(name, values)
-
-    # implicitly forward all other methods and attributes to self.env
-    def __getattr__(self, name):
-        if name.startswith("_"):
-            raise AttributeError(f"attempted to get missing private attribute '{name}'")
-        return getattr(self.env, name)
-
    @property
    def unwrapped(self):
+        """Return the base non-wrapped environment."""
        return self.env.unwrapped

    def __repr__(self):
+        """Return the string representation of the vectorized environment."""
        return f"<{self.__class__.__name__}, {self.env}>"

-    def __del__(self):
-        self.env.__del__()
+    @property
+    def spec(self) -> EnvSpec | None:
+        """Gets the specification of the wrapped environment."""
+        return self.env.spec
+
+    @property
+    def observation_space(self) -> gym.Space:
+        """Gets the observation space of the vector environment."""
+        if self._observation_space is None:
+            return self.env.observation_space
+        return self._observation_space
+
+    @observation_space.setter
+    def observation_space(self, space: gym.Space):
+        """Sets the observation space of the vector environment."""
+        self._observation_space = space
+
+    @property
+    def action_space(self) -> gym.Space:
+        """Gets the action space of the vector environment."""
+        if self._action_space is None:
+            return self.env.action_space
+        return self._action_space
+
+    @action_space.setter
+    def action_space(self, space: gym.Space):
+        """Sets the action space of the vector environment."""
+        self._action_space = space
+
+    @property
+    def single_observation_space(self) -> gym.Space:
+        """Gets the single observation space of the vector environment."""
+        if self._single_observation_space is None:
+            return self.env.single_observation_space
+        return self._single_observation_space
+
+    @single_observation_space.setter
+    def single_observation_space(self, space: gym.Space):
+        """Sets the single observation space of the vector environment."""
+        self._single_observation_space = space
+
+    @property
+    def single_action_space(self) -> gym.Space:
+        """Gets the single action space of the vector environment."""
+        if self._single_action_space is None:
+            return self.env.single_action_space
+        return self._single_action_space
+
+    @single_action_space.setter
+    def single_action_space(self, space):
+        """Sets the single action space of the vector environment."""
+        self._single_action_space = space
+
+    @property
+    def num_envs(self) -> int:
+        """Gets the wrapped vector environment's number of sub-environments."""
+        return self.env.num_envs
+
+    @property
+    def render_mode(self) -> str | None:
+        """Returns the `render_mode` from the base environment."""
+        return self.env.render_mode
+
+
+class VectorObservationWrapper(VectorWrapper):
+    """Wraps the vectorized environment to allow a modular transformation of the observation.
+
+    Equivalent to :class:`gymnasium.ObservationWrapper` for vectorized environments.
+    """
+
+    def reset(
+        self,
+        *,
+        seed: int | list[int] | None = None,
+        options: dict[str, Any] | None = None,
+    ) -> tuple[ObsType, dict[str, Any]]:
+        """Modifies the observation returned from the environment ``reset`` using the :meth:`vector_observation`."""
+        obs, info = self.env.reset(seed=seed, options=options)
+        return self.vector_observation(obs), info
+
+    def step(
+        self, actions: ActType
+    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
+        """Modifies the observation returned from the environment ``step`` using the :meth:`vector_observation`."""
+        observation, reward, termination, truncation, info = self.env.step(actions)
+        return (
+            self.vector_observation(observation),
+            reward,
+            termination,
+            truncation,
+            self.update_final_obs(info),
+        )
+
+    def vector_observation(self, observation: ObsType) -> ObsType:
+        """Defines the vector observation transformation.
+
+        Args:
+            observation: A vector observation from the environment
+
+        Returns:
+            The transformed observation
+        """
+        raise NotImplementedError
+
+    def single_observation(self, observation: ObsType) -> ObsType:
+        """Defines the single observation transformation.
+
+        Args:
+            observation: A single observation from the environment
+
+        Returns:
+            The transformed observation
+        """
+        raise NotImplementedError
+
+    def update_final_obs(self, info: dict[str, Any]) -> dict[str, Any]:
+        """Updates the `final_obs` in the info using `single_observation`."""
+        if "final_observation" in info:
+            for i, obs in enumerate(info["final_observation"]):
+                if obs is not None:
+                    info["final_observation"][i] = self.single_observation(obs)
+        return info
+
+
+class VectorActionWrapper(VectorWrapper):
+    """Wraps the vectorized environment to allow a modular transformation of the actions.
+
+    Equivalent to :class:`gymnasium.ActionWrapper` for vectorized environments.
+    """
+
+    def step(
+        self, actions: ActType
+    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
+        """Steps through the environment using a modified action by :meth:`actions`."""
+        return self.env.step(self.actions(actions))
+
+    def actions(self, actions: ActType) -> ActType:
+        """Transform the actions before sending them to the environment.
+
+        Args:
+            actions (ActType): the actions to transform
+
+        Returns:
+            ActType: the transformed actions
+        """
+        raise NotImplementedError
+
+
+class VectorRewardWrapper(VectorWrapper):
+    """Wraps the vectorized environment to allow a modular transformation of the reward.
+
+    Equivalent to :class:`gymnasium.RewardWrapper` for vectorized environments.
+    """
+
+    def step(
+        self, actions: ActType
+    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
+        """Steps through the environment returning a reward modified by :meth:`rewards`."""
+        observation, reward, termination, truncation, info = self.env.step(actions)
+        return observation, self.rewards(reward), termination, truncation, info
+
+    def rewards(self, reward: ArrayType) -> ArrayType:
+        """Transform the reward before returning it.
+ + Args: + reward (array): the reward to transform + + Returns: + array: the transformed reward + """ + raise NotImplementedError diff --git a/gymnasium/wrappers/README.md b/gymnasium/wrappers/README.md deleted file mode 100644 index c14b0d308..000000000 --- a/gymnasium/wrappers/README.md +++ /dev/null @@ -1,18 +0,0 @@ -# Wrappers - -Wrappers are used to transform an environment in a modular way: - -```python -import gymnasium as gym -env = gym.make('CartPole-v1') -env = MyWrapper(env) -``` - -## Quick tips for writing your own wrapper - -- Don't forget to call `super(class_name, self).__init__(env)` if you override the wrapper's `__init__` function -- You can access the inner environment with `self.unwrapped` -- You can access the previous wrapper using `self.env` -- The variables `metadata`, `action_space`, `observation_space`, `reward_range`, and `spec` are copied to `self` from the previous layer -- Create a wrapped function for at least one of the following: `__init__(self, env)`, `step`, `reset`, `render`, `close`, or `seed` -- Your layered function should take its input from the previous layer (`self.env`) and/or the inner layer (`self.unwrapped`) diff --git a/gymnasium/wrappers/__init__.py b/gymnasium/wrappers/__init__.py index fb9c66714..d5563e010 100644 --- a/gymnasium/wrappers/__init__.py +++ b/gymnasium/wrappers/__init__.py @@ -1,9 +1,8 @@ -"""Module of wrapper classes. +"""Wrappers are a convenient way to modify an existing environment without having to alter the underlying code directly. -Wrappers are a convenient way to modify an existing environment without having to alter the underlying code directly. -Using wrappers will allow you to avoid a lot of boilerplate code and make your environment more modular. Wrappers can -also be chained to combine their effects. -Most environments that are generated via :meth:`gymnasium.make` will already be wrapped by default. +Using wrappers will allow you to avoid a lot of boilerplate code and make your environment more modular. +Importantly, wrappers can be chained to combine their effects and most environments that are generated via +:meth:`gymnasium.make` will already be wrapped by default. In order to wrap an environment, you must first initialize a base environment. Then you can pass this environment along with (possibly optional) parameters to the wrapper's constructor. @@ -46,27 +45,135 @@ If you need a wrapper to do more complicated tasks, you can inherit from the :cl If you'd like to implement your own custom wrapper, check out `the corresponding tutorial <../../tutorials/implementing_custom_wrappers>`_.
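+ +Example of chaining wrappers to combine their effects (a minimal sketch; any environment with a :class:`Box` action space, such as ``Pendulum-v1``, works here): + >>> import gymnasium as gym + >>> from gymnasium.wrappers import ClipAction, RescaleAction + >>> env = gym.make("Pendulum-v1") + >>> env = ClipAction(RescaleAction(env, min_action=0.0, max_action=1.0))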
""" +# pyright: reportUnsupportedDunderAll=false +import importlib +import re + +from gymnasium.error import DeprecatedWrapper +from gymnasium.wrappers import vector from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing -from gymnasium.wrappers.autoreset import AutoResetWrapper -from gymnasium.wrappers.clip_action import ClipAction -from gymnasium.wrappers.compatibility import EnvCompatibility -from gymnasium.wrappers.env_checker import PassiveEnvChecker -from gymnasium.wrappers.filter_observation import FilterObservation -from gymnasium.wrappers.flatten_observation import FlattenObservation -from gymnasium.wrappers.frame_stack import FrameStack, LazyFrames -from gymnasium.wrappers.gray_scale_observation import GrayScaleObservation -from gymnasium.wrappers.human_rendering import HumanRendering -from gymnasium.wrappers.normalize import NormalizeObservation, NormalizeReward -from gymnasium.wrappers.order_enforcing import OrderEnforcing -from gymnasium.wrappers.pixel_observation import PixelObservationWrapper -from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistics -from gymnasium.wrappers.record_video import RecordVideo, capped_cubic_video_schedule -from gymnasium.wrappers.render_collection import RenderCollection -from gymnasium.wrappers.rescale_action import RescaleAction -from gymnasium.wrappers.resize_observation import ResizeObservation -from gymnasium.wrappers.step_api_compatibility import StepAPICompatibility -from gymnasium.wrappers.time_aware_observation import TimeAwareObservation -from gymnasium.wrappers.time_limit import TimeLimit -from gymnasium.wrappers.transform_observation import TransformObservation -from gymnasium.wrappers.transform_reward import TransformReward -from gymnasium.wrappers.vector_list_info import VectorListInfo +from gymnasium.wrappers.common import ( + Autoreset, + OrderEnforcing, + PassiveEnvChecker, + RecordEpisodeStatistics, + TimeLimit, +) +from gymnasium.wrappers.rendering import HumanRendering, RecordVideo, RenderCollection +from gymnasium.wrappers.stateful_action import StickyAction +from gymnasium.wrappers.stateful_observation import ( + DelayObservation, + FrameStackObservation, + MaxAndSkipObservation, + NormalizeObservation, + TimeAwareObservation, +) +from gymnasium.wrappers.stateful_reward import NormalizeReward +from gymnasium.wrappers.transform_action import ( + ClipAction, + RescaleAction, + TransformAction, +) +from gymnasium.wrappers.transform_observation import ( + DtypeObservation, + FilterObservation, + FlattenObservation, + GrayscaleObservation, + RenderObservation, + RescaleObservation, + ReshapeObservation, + ResizeObservation, + TransformObservation, +) +from gymnasium.wrappers.transform_reward import ClipReward, TransformReward + + +__all__ = [ + "vector", + # --- Observation wrappers --- + "AtariPreprocessing", + "DelayObservation", + "DtypeObservation", + "FilterObservation", + "FlattenObservation", + "FrameStackObservation", + "GrayscaleObservation", + "TransformObservation", + "MaxAndSkipObservation", + "NormalizeObservation", + "RenderObservation", + "ResizeObservation", + "ReshapeObservation", + "RescaleObservation", + "TimeAwareObservation", + # --- Action Wrappers --- + "ClipAction", + "TransformAction", + "RescaleAction", + # "NanAction", + "StickyAction", + # --- Reward wrappers --- + "ClipReward", + "TransformReward", + "NormalizeReward", + # --- Common --- + "TimeLimit", + "Autoreset", + "PassiveEnvChecker", + "OrderEnforcing", + "RecordEpisodeStatistics", + # --- Rendering --- + 
"RenderCollection", + "RecordVideo", + "HumanRendering", + # --- Conversion --- + "JaxToNumpy", + "JaxToTorch", + "NumpyToTorch", +] + +# As these wrappers requires `jax` or `torch`, they are loaded by runtime for users trying to access them +# to avoid `import jax` or `import torch` on `import gymnasium`. +_wrapper_to_class = { + # data converters + "JaxToNumpy": "jax_to_numpy", + "JaxToTorch": "jax_to_torch", + "NumpyToTorch": "numpy_to_torch", +} + +_renamed_wrapper = { + "AutoResetWrapper": "Autoreset", + "FrameStack": "FrameStackObservation", + "PixelObservationWrapper": "RenderObservation", + "VectorListInfo": "vector.DictInfoToList", +} + + +def __getattr__(wrapper_name: str): + """Load a wrapper by name. + + This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used. + Errors will be raised if the wrapper does not exist or if the version is not the latest. + + Args: + wrapper_name: The name of a wrapper to load. + + Returns: + The specified wrapper. + + Raises: + AttributeError: If the wrapper does not exist. + DeprecatedWrapper: If the version is not the latest. + """ + # Check if the requested wrapper is in the _wrapper_to_class dictionary + if wrapper_name in _wrapper_to_class: + import_stmt = f"gymnasium.wrappers.{_wrapper_to_class[wrapper_name]}" + module = importlib.import_module(import_stmt) + return getattr(module, wrapper_name) + + elif wrapper_name in _renamed_wrapper: + raise AttributeError( + f"{wrapper_name!r} has been renamed with `wrappers.{_renamed_wrapper[wrapper_name]}`" + ) + + raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}") diff --git a/gymnasium/wrappers/atari_preprocessing.py b/gymnasium/wrappers/atari_preprocessing.py index 019ab9ab4..9caefecc9 100644 --- a/gymnasium/wrappers/atari_preprocessing.py +++ b/gymnasium/wrappers/atari_preprocessing.py @@ -5,14 +5,14 @@ import gymnasium as gym from gymnasium.spaces import Box -try: - import cv2 -except ImportError: - cv2 = None +__all__ = ["AtariPreprocessing"] class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs): - """Atari 2600 preprocessing wrapper. + """Implements the common preprocessing techniques for Atari environments (excluding frame stacking). + + For frame stacking use :class:`gymnasium.wrappers.FrameStackObservation`. + No vector version of the wrapper exists This class follows the guidelines in Machado et al. (2018), "Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents". @@ -20,13 +20,22 @@ class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs): Specifically, the following preprocess stages applies to the atari environment: - Noop Reset: Obtains the initial state by taking a random number of no-ops on reset, default max 30 no-ops. - - Frame skipping: The number of frames skipped between steps, 4 by default - - Max-pooling: Pools over the most recent two observations from the frame skips + - Frame skipping: The number of frames skipped between steps, 4 by default. + - Max-pooling: Pools over the most recent two observations from the frame skips. - Termination signal when a life is lost: When the agent losses a life during the environment, then the environment is terminated. Turned off by default. Not recommended by Machado et al. (2018). - - Resize to a square image: Resizes the atari environment original observation shape from 210x180 to 84x84 by default - - Grayscale observation: If the observation is colour or greyscale, by default, greyscale. 
- - Scale observation: If to scale the observation between [0, 1) or [0, 255), by default, not scaled. + - Resize to a square image: Resizes the atari environment original observation shape from 210x160 to 84x84 by default. + - Grayscale observation: Makes the observation greyscale, enabled by default. + - Grayscale new axis: Extends the last channel of the observation such that the image is 3-dimensional, not enabled by default. + - Scale observation: Whether to scale the observation between [0, 1) or [0, 255), not scaled by default. + + Example: + >>> import gymnasium as gym # doctest: +SKIP + >>> env = gym.make("ALE/Adventure-v5") # doctest: +SKIP + >>> env = AtariPreprocessing(env, noop_max=10, frame_skip=1, screen_size=84, terminal_on_life_loss=True, grayscale_obs=False, grayscale_newaxis=False) # doctest: +SKIP + + Change logs: + * Added in gym v0.12.2 (gym #1455) """ def __init__( @@ -46,7 +55,7 @@ class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs): env (Env): The environment to apply the preprocessing noop_max (int): For No-op reset, the max number no-ops actions are taken at reset, to turn off, set to 0. frame_skip (int): The number of frames between new observation the agents observations effecting the frequency at which the agent experiences the game. - screen_size (int): resize Atari frame + screen_size (int): resize Atari frame. terminal_on_life_loss (bool): `if True`, then :meth:`step()` returns `terminated=True` whenever a life is lost. grayscale_obs (bool): if True, then gray scale observation is returned, otherwise, RGB observation @@ -72,10 +81,13 @@ class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs): ) gym.Wrapper.__init__(self, env) - if cv2 is None: + try: + import cv2 # noqa: F401 + except ImportError: raise gym.error.DependencyNotInstalled( "opencv-python package not installed, run `pip install gymnasium[other]` to get dependencies for atari" ) + assert frame_skip > 0 assert screen_size > 0 assert noop_max >= 0 @@ -187,7 +199,9 @@ class AtariPreprocessing(gym.Wrapper, gym.utils.RecordConstructorArgs): def _get_obs(self): if self.frame_skip > 1: # more efficient in-place pooling np.maximum(self.obs_buffer[0], self.obs_buffer[1], out=self.obs_buffer[0]) - assert cv2 is not None + + import cv2 + obs = cv2.resize( self.obs_buffer[0], (self.screen_size, self.screen_size), diff --git a/gymnasium/wrappers/autoreset.py b/gymnasium/wrappers/autoreset.py deleted file mode 100644 index e3e49f239..000000000 --- a/gymnasium/wrappers/autoreset.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Wrapper that autoreset environments when `terminated=True` or `truncated=True`.""" -from __future__ import annotations - -from copy import deepcopy -from typing import TYPE_CHECKING - -import gymnasium as gym - - -if TYPE_CHECKING: - from gymnasium.envs.registration import EnvSpec - - -class AutoResetWrapper(gym.Wrapper, gym.utils.RecordConstructorArgs): - """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`. - - When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called, - and the return format of :meth:`self.step` is as follows: ``(new_obs, final_reward, final_terminated, final_truncated, info)`` - with new step API and ``(new_obs, final_reward, final_done, info)`` with the old step API.
- - - ``new_obs`` is the first observation after calling :meth:`self.env.reset` - - ``final_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`. - - ``final_terminated`` is the terminated value before calling :meth:`self.env.reset`. - - ``final_truncated`` is the truncated value before calling :meth:`self.env.reset`. Both `final_terminated` and `final_truncated` cannot be False. - - ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`, - with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step` - and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`. - - Warning: - When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns `terminated` or `truncated`, a - new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the - final reward, terminated and truncated state from the previous episode. - If you need the final state from the previous episode, you need to retrieve it via the - "final_observation" key in the info dict. - Make sure you know what you're doing if you use this wrapper! - """ - - def __init__(self, env: gym.Env): - """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`. - - Args: - env (gym.Env): The environment to apply the wrapper - """ - gym.utils.RecordConstructorArgs.__init__(self) - gym.Wrapper.__init__(self, env) - - def step(self, action): - """Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered. - - Args: - action: The action to take - - Returns: - The autoreset environment :meth:`step` - """ - obs, reward, terminated, truncated, info = self.env.step(action) - if terminated or truncated: - new_obs, new_info = self.env.reset() - assert ( - "final_observation" not in new_info - ), 'info dict cannot contain key "final_observation" ' - assert ( - "final_info" not in new_info - ), 'info dict cannot contain key "final_info" ' - - new_info["final_observation"] = obs - new_info["final_info"] = info - - obs = new_obs - info = new_info - - return obs, reward, terminated, truncated, info - - @property - def spec(self) -> EnvSpec | None: - """Modifies the environment spec to specify the `autoreset=True`.""" - if self._cached_spec is not None: - return self._cached_spec - - env_spec = self.env.spec - if env_spec is not None: - env_spec = deepcopy(env_spec) - env_spec.autoreset = True - - self._cached_spec = env_spec - return env_spec diff --git a/gymnasium/wrappers/clip_action.py b/gymnasium/wrappers/clip_action.py deleted file mode 100644 index c11768107..000000000 --- a/gymnasium/wrappers/clip_action.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Wrapper for clipping actions within a valid bound.""" -import numpy as np - -import gymnasium as gym -from gymnasium.spaces import Box - - -class ClipAction(gym.ActionWrapper, gym.utils.RecordConstructorArgs): - """Clip the continuous action within the valid :class:`Box` observation space bound. - - Example: - >>> import gymnasium as gym - >>> from gymnasium.wrappers import ClipAction - >>> env = gym.make("Hopper-v4") - >>> env = ClipAction(env) - >>> env.action_space - Box(-1.0, 1.0, (3,), float32) - >>> _ = env.reset(seed=42) - >>> _ = env.step(np.array([5.0, -2.0, 0.0])) - ... 
# Executes the action np.array([1.0, -1.0, 0]) in the base environment - """ - - def __init__(self, env: gym.Env): - """A wrapper for clipping continuous actions within the valid bound. - - Args: - env: The environment to apply the wrapper - """ - assert isinstance(env.action_space, Box) - - gym.utils.RecordConstructorArgs.__init__(self) - gym.ActionWrapper.__init__(self, env) - - def action(self, action): - """Clips the action within the valid bounds. - - Args: - action: The action to clip - - Returns: - The clipped action - """ - return np.clip(action, self.action_space.low, self.action_space.high) diff --git a/gymnasium/wrappers/common.py b/gymnasium/wrappers/common.py new file mode 100644 index 000000000..900131531 --- /dev/null +++ b/gymnasium/wrappers/common.py @@ -0,0 +1,536 @@ +"""A collection of common wrappers. + +* ``TimeLimit`` - Provides a time limit on the number of steps for an environment before it truncates +* ``Autoreset`` - Auto-resets the environment +* ``PassiveEnvChecker`` - Passive environment checker that does not modify any environment data +* ``OrderEnforcing`` - Enforces the order of function calls to environments +* ``RecordEpisodeStatistics`` - Records the episode statistics +""" +from __future__ import annotations + +import time +from collections import deque +from copy import deepcopy +from typing import TYPE_CHECKING, Any, SupportsFloat + +import gymnasium as gym +from gymnasium import logger +from gymnasium.core import ActType, ObsType, RenderFrame +from gymnasium.error import ResetNeeded +from gymnasium.utils.passive_env_checker import ( + check_action_space, + check_observation_space, + env_render_passive_checker, + env_reset_passive_checker, + env_step_passive_checker, +) + + +if TYPE_CHECKING: + from gymnasium.envs.registration import EnvSpec + + +__all__ = [ + "TimeLimit", + "Autoreset", + "PassiveEnvChecker", + "OrderEnforcing", + "RecordEpisodeStatistics", +] + + +class TimeLimit( + gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs +): + """Limits the number of steps for an environment by truncating it if a maximum number of timesteps is exceeded. + + If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued. + Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP. + No vector wrapper exists. + + Example using the TimeLimit wrapper: + >>> import gymnasium as gym + >>> from gymnasium.wrappers import TimeLimit + >>> from gymnasium.envs.classic_control import CartPoleEnv + + >>> spec = gym.spec("CartPole-v1") + >>> spec.max_episode_steps + 500 + >>> env = gym.make("CartPole-v1") + >>> env # TimeLimit is included within the environment stack + <TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>> + >>> env.spec # doctest: +ELLIPSIS + EnvSpec(id='CartPole-v1', ..., max_episode_steps=500, ...) + >>> env = gym.make("CartPole-v1", max_episode_steps=3) + >>> env.spec # doctest: +ELLIPSIS + EnvSpec(id='CartPole-v1', ..., max_episode_steps=3, ...)
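+ + The overridden limit is also reported through the wrapper's updated spec (a small illustrative check): + >>> env.spec.max_episode_steps + 3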
+ >>> env = TimeLimit(CartPoleEnv(), max_episode_steps=10) + >>> env + <TimeLimit<CartPoleEnv instance>> + + Example of `TimeLimit` determining the episode step + >>> env = gym.make("CartPole-v1", max_episode_steps=3) + >>> _ = env.reset(seed=123) + >>> _ = env.action_space.seed(123) + >>> _, _, terminated, truncated, _ = env.step(env.action_space.sample()) + >>> terminated, truncated + (False, False) + >>> _, _, terminated, truncated, _ = env.step(env.action_space.sample()) + >>> terminated, truncated + (False, False) + >>> _, _, terminated, truncated, _ = env.step(env.action_space.sample()) + >>> terminated, truncated + (False, True) + + Change logs: + * v0.10.6 - Initially added + * v0.25.0 - With the step API update, the termination and truncation signal is returned separately. + """ + + def __init__( + self, + env: gym.Env, + max_episode_steps: int, + ): + """Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur. + + Args: + env: The environment to apply the wrapper + max_episode_steps: The maximum number of steps before the environment is truncated + """ + gym.utils.RecordConstructorArgs.__init__( + self, max_episode_steps=max_episode_steps + ) + gym.Wrapper.__init__(self, env) + + self._max_episode_steps = max_episode_steps + self._elapsed_steps = None + + def step( + self, action: ActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + """Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate. + + Args: + action: The environment step action + + Returns: + The environment step ``(observation, reward, terminated, truncated, info)`` with `truncated=True` + if the number of steps elapsed >= max episode steps + + """ + observation, reward, terminated, truncated, info = self.env.step(action) + self._elapsed_steps += 1 + + if self._elapsed_steps >= self._max_episode_steps: + truncated = True + + return observation, reward, terminated, truncated, info + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[ObsType, dict[str, Any]]: + """Resets the environment with `seed` and `options` and sets the number of steps elapsed to zero. + + Args: + seed: Seed for the environment + options: Options for the environment + + Returns: + The reset environment + """ + self._elapsed_steps = 0 + return super().reset(seed=seed, options=options) + + @property + def spec(self) -> EnvSpec | None: + """Modifies the environment spec to include the `max_episode_steps=self._max_episode_steps`.""" + if self._cached_spec is not None: + return self._cached_spec + + env_spec = self.env.spec + if env_spec is not None: + env_spec = deepcopy(env_spec) + env_spec.max_episode_steps = self._max_episode_steps + + self._cached_spec = env_spec + return env_spec + + +class Autoreset( + gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs +): + """The wrapped environment is automatically reset when a terminated or truncated state is reached. + + When calling step causes :meth:`Env.step` to return `terminated=True` or `truncated=True`, :meth:`Env.reset` is called, + and the return format of :meth:`self.step` is as follows: ``(new_obs, final_reward, final_terminated, final_truncated, info)`` + with the new step API and ``(new_obs, final_reward, final_done, info)`` with the old step API. + No vector version of the wrapper exists.
+ + - ``obs`` is the first observation after calling :meth:`self.env.reset` + - ``final_reward`` is the reward after calling :meth:`self.env.step`, prior to calling :meth:`self.env.reset`. + - ``final_terminated`` is the terminated value before calling :meth:`self.env.reset`. + - ``final_truncated`` is the truncated value before calling :meth:`self.env.reset`. `final_terminated` and `final_truncated` cannot both be False. + - ``info`` is a dict containing all the keys from the info dict returned by the call to :meth:`self.env.reset`, + with an additional key "final_observation" containing the observation returned by the last call to :meth:`self.env.step` + and "final_info" containing the info dict returned by the last call to :meth:`self.env.step`. + + Warning: + When using this wrapper to collect rollouts, note that when :meth:`Env.step` returns `terminated` or `truncated`, a + new observation from after calling :meth:`Env.reset` is returned by :meth:`Env.step` alongside the + final reward, terminated and truncated state from the previous episode. + If you need the final state from the previous episode, you need to retrieve it via the + "final_observation" key in the info dict. + Make sure you know what you're doing if you use this wrapper! + + Change logs: + * v0.24.0 - Initially added as `AutoResetWrapper` + * v1.0.0 - Renamed to `Autoreset` and autoreset order was changed to reset on the step after the environment terminates or truncates. As a result, `"final_observation"` and `"final_info"` are removed. + """ + + def __init__(self, env: gym.Env): + """A class for providing an automatic reset functionality for gymnasium environments when calling :meth:`self.step`. + + Args: + env (gym.Env): The environment to apply the wrapper + """ + gym.utils.RecordConstructorArgs.__init__(self) + gym.Wrapper.__init__(self, env) + + def step( + self, action: ActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + """Steps through the environment with action and resets the environment if a terminated or truncated signal is encountered. + + Args: + action: The action to take + + Returns: + The autoreset environment :meth:`step` + """ + obs, reward, terminated, truncated, info = self.env.step(action) + + if terminated or truncated: + new_obs, new_info = self.env.reset() + + assert ( + "final_observation" not in new_info + ), f'new info dict already contains "final_observation", info keys: {new_info.keys()}' + assert ( + "final_info" not in new_info + ), f'new info dict already contains "final_info", info keys: {new_info.keys()}' + + new_info["final_observation"] = obs + new_info["final_info"] = info + + obs = new_obs + info = new_info + + return obs, reward, terminated, truncated, info + + +class PassiveEnvChecker( + gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs +): + """A passive wrapper that surrounds the ``step``, ``reset`` and ``render`` functions to check they follow Gymnasium's API. + + This wrapper is automatically applied during make and can be disabled with `disable_env_checker`. + No vector version of the wrapper exists.
+ + Example: + >>> import gymnasium as gym + >>> env = gym.make("CartPole-v1") + >>> env + <TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>> + >>> env = gym.make("CartPole-v1", disable_env_checker=True) + >>> env + <TimeLimit<OrderEnforcing<CartPoleEnv<CartPole-v1>>>> + + Change logs: + * v0.24.1 - Initially added, however it was broken in several ways + * v0.25.0 - Bugs were all fixed + * v0.29.0 - Removed warnings for infinite bounds for Box observation and action spaces and irregular bound shapes + """ + + def __init__(self, env: gym.Env[ObsType, ActType]): + """Initialises the wrapper with the environment, running the observation and action space tests.""" + gym.utils.RecordConstructorArgs.__init__(self) + gym.Wrapper.__init__(self, env) + + assert hasattr( + env, "action_space" + ), "The environment must specify an action space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/" + check_action_space(env.action_space) + assert hasattr( + env, "observation_space" + ), "The environment must specify an observation space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/" + check_observation_space(env.observation_space) + + self.checked_reset: bool = False + self.checked_step: bool = False + self.checked_render: bool = False + self.close_called: bool = False + + def step( + self, action: ActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + """Steps through the environment that on the first call will run the `passive_env_step_check`.""" + if self.checked_step is False: + self.checked_step = True + return env_step_passive_checker(self.env, action) + else: + return self.env.step(action) + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[ObsType, dict[str, Any]]: + """Resets the environment that on the first call will run the `passive_env_reset_check`.""" + if self.checked_reset is False: + self.checked_reset = True + return env_reset_passive_checker(self.env, seed=seed, options=options) + else: + return self.env.reset(seed=seed, options=options) + + def render(self) -> RenderFrame | list[RenderFrame] | None: + """Renders the environment that on the first call will run the `passive_env_render_check`.""" + if self.checked_render is False: + self.checked_render = True + return env_render_passive_checker(self.env) + else: + return self.env.render() + + @property + def spec(self) -> EnvSpec | None: + """Modifies the environment spec such that `disable_env_checker=False`.""" + if self._cached_spec is not None: + return self._cached_spec + + env_spec = self.env.spec + if env_spec is not None: + env_spec = deepcopy(env_spec) + env_spec.disable_env_checker = False + + self._cached_spec = env_spec + return env_spec + + def close(self): + """Warns if calling close on a closed environment fails.""" + if not self.close_called: + self.close_called = True + return self.env.close() + else: + try: + return self.env.close() + except Exception as e: + logger.warn( + "Calling `env.close()` on the closed environment should be allowed, but it raised the following exception." + ) + raise e + + +class OrderEnforcing( + gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs +): + """Will produce an error if ``step`` or ``render`` is called before ``reset``. + + No vector version of the wrapper exists. + + Example: + >>> import gymnasium as gym + >>> from gymnasium.wrappers import OrderEnforcing + >>> env = gym.make("CartPole-v1", render_mode="human") + >>> env = OrderEnforcing(env) + >>> env.step(0) + Traceback (most recent call last): + ...
+ gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset() + >>> env.render() + Traceback (most recent call last): + ... + gymnasium.error.ResetNeeded: Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper. + >>> _ = env.reset() + >>> env.render() + >>> _ = env.step(0) + >>> env.close() + + Change logs: + * v0.22.0 - Initially added + * v0.24.0 - Added order enforcing for the render function + """ + + def __init__( + self, + env: gym.Env[ObsType, ActType], + disable_render_order_enforcing: bool = False, + ): + """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`. + + Args: + env: The environment to wrap + disable_render_order_enforcing: Whether to disable render order enforcing + """ + gym.utils.RecordConstructorArgs.__init__( + self, disable_render_order_enforcing=disable_render_order_enforcing + ) + gym.Wrapper.__init__(self, env) + + self._has_reset: bool = False + self._disable_render_order_enforcing: bool = disable_render_order_enforcing + + def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict]: + """Steps through the environment.""" + if not self._has_reset: + raise ResetNeeded("Cannot call env.step() before calling env.reset()") + return super().step(action) + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[ObsType, dict[str, Any]]: + """Resets the environment with `seed` and `options`.""" + self._has_reset = True + return super().reset(seed=seed, options=options) + + def render(self) -> RenderFrame | list[RenderFrame] | None: + """Renders the environment.""" + if not self._disable_render_order_enforcing and not self._has_reset: + raise ResetNeeded( + "Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, " + "set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper." + ) + return super().render() + + @property + def has_reset(self): + """Returns if the environment has been reset before.""" + return self._has_reset + + @property + def spec(self) -> EnvSpec | None: + """Modifies the environment spec to add `order_enforce=True`.""" + if self._cached_spec is not None: + return self._cached_spec + + env_spec = self.env.spec + if env_spec is not None: + env_spec = deepcopy(env_spec) + env_spec.order_enforce = True + + self._cached_spec = env_spec + return env_spec + + +class RecordEpisodeStatistics( + gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs +): + """This wrapper will keep track of cumulative rewards and episode lengths. + + At the end of an episode, the statistics of the episode will be added to ``info`` + using the key ``episode``. If using a vectorized environment, the key + ``_episode`` is also used, indicating whether the env at the respective index has + the episode statistics. + A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.RecordEpisodeStatistics`. + + After the completion of an episode, ``info`` will look like this:: + + >>> info = { + ... "episode": { + ... "r": "<cumulative reward>", + ... "l": "<episode length>", + ... "t": "<elapsed time since beginning of episode>" + ... }, + ... } + + For vectorized environments the output will be in the form of:: + + >>> infos = { + ... "final_observation": "<array of length num-envs>", + ... "_final_observation": "<boolean array of length num-envs>", + ... "final_info": "<array of length num-envs>", + ... "_final_info": "<boolean array of length num-envs>", + ... "episode": { + ... "r": "<array of cumulative reward>", + ... "l": "<array of episode length>", + ... "t": "<array of elapsed time since beginning of episode>" + ... }, + ... "_episode": "<boolean array of length num-envs>" + ...
} + + Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via + :attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively. + + Attributes: + * time_queue: The time length of the last ``buffer_length``-many episodes + * return_queue: The cumulative rewards of the last ``buffer_length``-many episodes + * length_queue: The lengths of the last ``buffer_length``-many episodes + + Change logs: + * v0.15.4 - Initially added + * v1.0.0 - Removed vector environment support, which moved to `wrappers.vector.RecordEpisodeStatistics`, and added attribute ``time_queue`` + """ + + def __init__( + self, + env: gym.Env[ObsType, ActType], + buffer_length: int | None = 100, + stats_key: str = "episode", + ): + """This wrapper will keep track of cumulative rewards and episode lengths. + + Args: + env (Env): The environment to apply the wrapper + buffer_length: The size of the buffers :attr:`return_queue` and :attr:`length_queue` + stats_key: The info key for the episode statistics + """ + gym.utils.RecordConstructorArgs.__init__(self) + gym.Wrapper.__init__(self, env) + + self._stats_key = stats_key + + self.episode_count = 0 + self.episode_start_time: float = -1 + self.episode_returns: float = 0.0 + self.episode_lengths: int = 0 + + self.time_queue: deque[float] = deque(maxlen=buffer_length) + self.return_queue: deque[float] = deque(maxlen=buffer_length) + self.length_queue: deque[int] = deque(maxlen=buffer_length) + + def step( + self, action: ActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + """Steps through the environment, recording the episode statistics.""" + obs, reward, terminated, truncated, info = super().step(action) + + self.episode_returns += reward + self.episode_lengths += 1 + + if terminated or truncated: + assert self._stats_key not in info + + episode_time_length = round( + time.perf_counter() - self.episode_start_time, 6 + ) + info[self._stats_key] = { + "r": self.episode_returns, + "l": self.episode_lengths, + "t": episode_time_length, + } + + self.time_queue.append(episode_time_length) + self.return_queue.append(self.episode_returns) + self.length_queue.append(self.episode_lengths) + + self.episode_count += 1 + + return obs, reward, terminated, truncated, info + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[ObsType, dict[str, Any]]: + """Resets the environment using seed and options and resets the episode rewards and lengths.""" + obs, info = super().reset(seed=seed, options=options) + + self.episode_start_time = time.perf_counter() + self.episode_returns = 0.0 + self.episode_lengths = 0 + + return obs, info diff --git a/gymnasium/wrappers/compatibility.py b/gymnasium/wrappers/compatibility.py deleted file mode 100644 index 0d7b8a5aa..000000000 --- a/gymnasium/wrappers/compatibility.py +++ /dev/null @@ -1,129 +0,0 @@ -"""A compatibility wrapper converting an old-style environment into a valid environment.""" -from typing import Any, Dict, Optional, Protocol, Tuple, runtime_checkable - -import gymnasium as gym -from gymnasium import logger -from gymnasium.core import ObsType -from gymnasium.utils.step_api_compatibility import ( - convert_to_terminated_truncated_step_api, -) - - -@runtime_checkable -class LegacyEnv(Protocol): - """A protocol for environments using the old step API.""" - - observation_space: gym.Space - action_space: gym.Space - - def reset(self) -> Any: - """Reset the environment and return the initial observation.""" - ...
- - def step(self, action: Any) -> Tuple[Any, float, bool, Dict]: - """Run one timestep of the environment's dynamics.""" - ... - - def render(self, mode: Optional[str] = "human") -> Any: - """Render the environment.""" - ... - - def close(self): - """Close the environment.""" - ... - - def seed(self, seed: Optional[int] = None): - """Set the seed for this env's random number generator(s).""" - ... - - -class EnvCompatibility(gym.Env): - r"""A wrapper which can transform an environment from the old API to the new API. - - Old step API refers to step() method returning (observation, reward, done, info), and reset() only retuning the observation. - New step API refers to step() method returning (observation, reward, terminated, truncated, info) and reset() returning (observation, info). - (Refer to docs for details on the API change) - - Known limitations: - - Environments that use `self.np_random` might not work as expected. - """ - - def __init__(self, old_env: LegacyEnv, render_mode: Optional[str] = None): - """A wrapper which converts old-style envs to valid modern envs. - - Some information may be lost in the conversion, so we recommend updating your environment. - - Args: - old_env (LegacyEnv): the env to wrap, implemented with the old API - render_mode (str): the render mode to use when rendering the environment, passed automatically to env.render - """ - logger.deprecation( - "The `gymnasium.make(..., apply_api_compatibility=...)` parameter is deprecated and will be removed in v1.0. " - "Instead use `gymnasium.make('GymV21Environment-v0', env_name=...)` or `from shimmy import GymV21CompatibilityV0`" - ) - - self.env = old_env - self.metadata = getattr(old_env, "metadata", {"render_modes": []}) - self.render_mode = render_mode - self.reward_range = getattr(old_env, "reward_range", None) - self.spec = getattr(old_env, "spec", None) - - self.observation_space = old_env.observation_space - self.action_space = old_env.action_space - - def reset( - self, seed: Optional[int] = None, options: Optional[dict] = None - ) -> Tuple[ObsType, dict]: - """Resets the environment. - - Args: - seed: the seed to reset the environment with - options: the options to reset the environment with - - Returns: - (observation, info) - """ - if seed is not None: - self.env.seed(seed) - # Options are ignored - - if self.render_mode == "human": - self.render() - - return self.env.reset(), {} - - def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]: - """Steps through the environment. - - Args: - action: action to step through the environment with - - Returns: - (observation, reward, terminated, truncated, info) - """ - obs, reward, done, info = self.env.step(action) - - if self.render_mode == "human": - self.render() - - return convert_to_terminated_truncated_step_api((obs, reward, done, info)) - - def render(self) -> Any: - """Renders the environment. 
- - Returns: - The rendering of the environment, depending on the render mode - """ - return self.env.render(mode=self.render_mode) - - def close(self): - """Closes the environment.""" - self.env.close() - - def __str__(self): - """Returns the wrapper name and the unwrapped environment string.""" - return f"<{type(self).__name__}{self.env}>" - - def __repr__(self): - """Returns the string representation of the wrapper.""" - return str(self) diff --git a/gymnasium/wrappers/env_checker.py b/gymnasium/wrappers/env_checker.py deleted file mode 100644 index f1961a74d..000000000 --- a/gymnasium/wrappers/env_checker.py +++ /dev/null @@ -1,95 +0,0 @@ -"""A passive environment checker wrapper for an environment's observation and action space along with the reset, step and render functions.""" -from __future__ import annotations - -from copy import deepcopy -from typing import TYPE_CHECKING - -import gymnasium as gym -from gymnasium import logger -from gymnasium.core import ActType -from gymnasium.utils.passive_env_checker import ( - check_action_space, - check_observation_space, - env_render_passive_checker, - env_reset_passive_checker, - env_step_passive_checker, -) - - -if TYPE_CHECKING: - from gymnasium.envs.registration import EnvSpec - - -class PassiveEnvChecker(gym.Wrapper, gym.utils.RecordConstructorArgs): - """A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gymnasium API.""" - - def __init__(self, env): - """Initialises the wrapper with the environments, run the observation and action space tests.""" - gym.utils.RecordConstructorArgs.__init__(self) - gym.Wrapper.__init__(self, env) - - assert hasattr( - env, "action_space" - ), "The environment must specify an action space. https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/" - check_action_space(env.action_space) - assert hasattr( - env, "observation_space" - ), "The environment must specify an observation space. 
https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/" - check_observation_space(env.observation_space) - - self.checked_reset = False - self.checked_step = False - self.checked_render = False - self.close_called = False - - def step(self, action: ActType): - """Steps through the environment that on the first call will run the `passive_env_step_check`.""" - if not self.checked_step: - self.checked_step = True - return env_step_passive_checker(self.env, action) - else: - return self.env.step(action) - - def reset(self, **kwargs): - """Resets the environment that on the first call will run the `passive_env_reset_check`.""" - if not self.checked_reset: - self.checked_reset = True - return env_reset_passive_checker(self.env, **kwargs) - else: - return self.env.reset(**kwargs) - - def render(self, *args, **kwargs): - """Renders the environment that on the first call will run the `passive_env_render_check`.""" - if not self.checked_render: - self.checked_render = True - return env_render_passive_checker(self.env, *args, **kwargs) - else: - return self.env.render(*args, **kwargs) - - @property - def spec(self) -> EnvSpec | None: - """Modifies the environment spec to such that `disable_env_checker=False`.""" - if self._cached_spec is not None: - return self._cached_spec - - env_spec = self.env.spec - if env_spec is not None: - env_spec = deepcopy(env_spec) - env_spec.disable_env_checker = False - - self._cached_spec = env_spec - return env_spec - - def close(self): - """Warns if calling close on a closed environment fails.""" - if not self.close_called: - self.close_called = True - return self.env.close() - else: - try: - return self.env.close() - except Exception as e: - logger.warn( - "Calling `env.close()` on the closed environment should be allowed, but it raised the following exception." - ) - raise e diff --git a/gymnasium/wrappers/filter_observation.py b/gymnasium/wrappers/filter_observation.py deleted file mode 100644 index dc905a805..000000000 --- a/gymnasium/wrappers/filter_observation.py +++ /dev/null @@ -1,92 +0,0 @@ -"""A wrapper for filtering dictionary observations by their keys.""" -import copy -from typing import Sequence - -import gymnasium as gym -from gymnasium import spaces - - -class FilterObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): - """Filter Dict observation space by the keys. - - Example: - >>> import gymnasium as gym - >>> from gymnasium.wrappers import TransformObservation - >>> env = gym.make("CartPole-v1") - >>> env = TransformObservation(env, lambda obs: {'obs': obs, 'time': 0}) - >>> env.observation_space = gym.spaces.Dict(obs=env.observation_space, time=gym.spaces.Discrete(1)) - >>> env.reset(seed=42) - ({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': 0}, {}) - >>> env = FilterObservation(env, filter_keys=['obs']) - >>> env.reset(seed=42) - ({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32)}, {}) - >>> env.step(0) - ({'obs': array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476], dtype=float32)}, 1.0, False, False, {}) - """ - - def __init__(self, env: gym.Env, filter_keys: Sequence[str] = None): - """A wrapper that filters dictionary observations by their keys. - - Args: - env: The environment to apply the wrapper - filter_keys: List of keys to be included in the observations. 
If ``None``, observations will not be filtered and this wrapper has no effect - - Raises: - ValueError: If the environment's observation space is not :class:`spaces.Dict` - ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space - """ - gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys) - gym.ObservationWrapper.__init__(self, env) - - wrapped_observation_space = env.observation_space - if not isinstance(wrapped_observation_space, spaces.Dict): - raise ValueError( - f"FilterObservationWrapper is only usable with dict observations, " - f"environment observation space is {type(wrapped_observation_space)}" - ) - - observation_keys = wrapped_observation_space.spaces.keys() - if filter_keys is None: - filter_keys = tuple(observation_keys) - - missing_keys = {key for key in filter_keys if key not in observation_keys} - if missing_keys: - raise ValueError( - "All the filter_keys must be included in the original observation space.\n" - f"Filter keys: {filter_keys}\n" - f"Observation keys: {observation_keys}\n" - f"Missing keys: {missing_keys}" - ) - - self.observation_space = type(wrapped_observation_space)( - [ - (name, copy.deepcopy(space)) - for name, space in wrapped_observation_space.spaces.items() - if name in filter_keys - ] - ) - - self._env = env - self._filter_keys = tuple(filter_keys) - - def observation(self, observation): - """Filters the observations. - - Args: - observation: The observation to filter - - Returns: - The filtered observations - """ - filter_observation = self._filter_observation(observation) - return filter_observation - - def _filter_observation(self, observation): - observation = type(observation)( - [ - (name, value) - for name, value in observation.items() - if name in self._filter_keys - ] - ) - return observation diff --git a/gymnasium/wrappers/flatten_observation.py b/gymnasium/wrappers/flatten_observation.py deleted file mode 100644 index 22d5aa2af..000000000 --- a/gymnasium/wrappers/flatten_observation.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Wrapper for flattening observations of an environment.""" -import gymnasium as gym -from gymnasium import spaces - - -class FlattenObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): - """Observation wrapper that flattens the observation. - - Example: - >>> import gymnasium as gym - >>> from gymnasium.wrappers import FlattenObservation - >>> env = gym.make("CarRacing-v2") - >>> env.observation_space.shape - (96, 96, 3) - >>> env = FlattenObservation(env) - >>> env.observation_space.shape - (27648,) - >>> obs, _ = env.reset() - >>> obs.shape - (27648,) - """ - - def __init__(self, env: gym.Env): - """Flattens the observations of an environment. - - Args: - env: The environment to apply the wrapper - """ - gym.utils.RecordConstructorArgs.__init__(self) - gym.ObservationWrapper.__init__(self, env) - - self.observation_space = spaces.flatten_space(env.observation_space) - - def observation(self, observation): - """Flattens an observation. 
- - Args: - observation: The observation to flatten - - Returns: - The flattened observation - """ - return spaces.flatten(self.env.observation_space, observation) diff --git a/gymnasium/wrappers/frame_stack.py b/gymnasium/wrappers/frame_stack.py deleted file mode 100644 index 1fbc37453..000000000 --- a/gymnasium/wrappers/frame_stack.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Wrapper that stacks frames.""" -from collections import deque -from typing import Union - -import numpy as np - -import gymnasium as gym -from gymnasium.error import DependencyNotInstalled -from gymnasium.spaces import Box - - -class LazyFrames: - """Ensures common frames are only stored once to optimize memory use. - - To further reduce the memory use, it is optionally to turn on lz4 to compress the observations. - - Note: - This object should only be converted to numpy array just before forward pass. - """ - - __slots__ = ("frame_shape", "dtype", "shape", "lz4_compress", "_frames") - - def __init__(self, frames: list, lz4_compress: bool = False): - """Lazyframe for a set of frames and if to apply lz4. - - Args: - frames (list): The frames to convert to lazy frames - lz4_compress (bool): Use lz4 to compress the frames internally - - Raises: - DependencyNotInstalled: lz4 is not installed - """ - self.frame_shape = tuple(frames[0].shape) - self.shape = (len(frames),) + self.frame_shape - self.dtype = frames[0].dtype - if lz4_compress: - try: - from lz4.block import compress - except ImportError as e: - raise DependencyNotInstalled( - "lz4 is not installed, run `pip install gymnasium[other]`" - ) from e - - frames = [compress(frame) for frame in frames] - self._frames = frames - self.lz4_compress = lz4_compress - - def __array__(self, dtype=None): - """Gets a numpy array of stacked frames with specific dtype. - - Args: - dtype: The dtype of the stacked frames - - Returns: - The array of stacked frames with dtype - """ - arr = self[:] - if dtype is not None: - return arr.astype(dtype) - return arr - - def __len__(self): - """Returns the number of frame stacks. - - Returns: - The number of frame stacks - """ - return self.shape[0] - - def __getitem__(self, int_or_slice: Union[int, slice]): - """Gets the stacked frames for a particular index or slice. - - Args: - int_or_slice: Index or slice to get items for - - Returns: - np.stacked frames for the int or slice - - """ - if isinstance(int_or_slice, int): - return self._check_decompress(self._frames[int_or_slice]) # single frame - return np.stack( - [self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0 - ) - - def __eq__(self, other): - """Checks that the current frames are equal to the other object.""" - return self.__array__() == other - - def _check_decompress(self, frame): - if self.lz4_compress: - from lz4.block import decompress - - return np.frombuffer(decompress(frame), dtype=self.dtype).reshape( - self.frame_shape - ) - return frame - - -class FrameStack(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): - """Observation wrapper that stacks the observations in a rolling manner. - - For example, if the number of stacks is 4, then the returned observation contains - the most recent 4 observations. For environment 'Pendulum-v1', the original observation - is an array with shape [3], so if we stack 4 observations, the processed observation - has shape [4, 3]. - - Note: - - To be memory efficient, the stacked observations are wrapped by :class:`LazyFrame`. - - The observation space must be :class:`Box` type. 
If one uses :class:`Dict` - as observation space, it should apply :class:`FlattenObservation` wrapper first. - - After :meth:`reset` is called, the frame buffer will be filled with the initial observation. - I.e. the observation returned by :meth:`reset` will consist of `num_stack` many identical frames. - - Example: - >>> import gymnasium as gym - >>> from gymnasium.wrappers import FrameStack - >>> env = gym.make("CarRacing-v2") - >>> env = FrameStack(env, 4) - >>> env.observation_space - Box(0, 255, (4, 96, 96, 3), uint8) - >>> obs, _ = env.reset() - >>> obs.shape - (4, 96, 96, 3) - """ - - def __init__( - self, - env: gym.Env, - num_stack: int, - lz4_compress: bool = False, - ): - """Observation wrapper that stacks the observations in a rolling manner. - - Args: - env (Env): The environment to apply the wrapper - num_stack (int): The number of frames to stack - lz4_compress (bool): Use lz4 to compress the frames internally - """ - gym.utils.RecordConstructorArgs.__init__( - self, num_stack=num_stack, lz4_compress=lz4_compress - ) - gym.ObservationWrapper.__init__(self, env) - - self.num_stack = num_stack - self.lz4_compress = lz4_compress - - self.frames = deque(maxlen=num_stack) - - low = np.repeat(self.observation_space.low[np.newaxis, ...], num_stack, axis=0) - high = np.repeat( - self.observation_space.high[np.newaxis, ...], num_stack, axis=0 - ) - self.observation_space = Box( - low=low, high=high, dtype=self.observation_space.dtype - ) - - def observation(self, observation): - """Converts the wrappers current frames to lazy frames. - - Args: - observation: Ignored - - Returns: - :class:`LazyFrames` object for the wrapper's frame buffer, :attr:`self.frames` - """ - assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack) - return LazyFrames(list(self.frames), self.lz4_compress) - - def step(self, action): - """Steps through the environment, appending the observation to the frame buffer. - - Args: - action: The action to step through the environment with - - Returns: - Stacked observations, reward, terminated, truncated, and information from the environment - """ - observation, reward, terminated, truncated, info = self.env.step(action) - self.frames.append(observation) - return self.observation(None), reward, terminated, truncated, info - - def reset(self, **kwargs): - """Reset the environment with kwargs. - - Args: - **kwargs: The kwargs for the environment reset - - Returns: - The stacked observations - """ - obs, info = self.env.reset(**kwargs) - - [self.frames.append(obs) for _ in range(self.num_stack)] - - return self.observation(None), info diff --git a/gymnasium/wrappers/gray_scale_observation.py b/gymnasium/wrappers/gray_scale_observation.py deleted file mode 100644 index 81172781b..000000000 --- a/gymnasium/wrappers/gray_scale_observation.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Wrapper that converts a color observation to grayscale.""" -import numpy as np - -import gymnasium as gym -from gymnasium.spaces import Box - - -class GrayScaleObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): - """Convert the image observation from RGB to gray scale. 
- - Example: - >>> import gymnasium as gym - >>> from gymnasium.wrappers import GrayScaleObservation - >>> env = gym.make("CarRacing-v2") - >>> env.observation_space - Box(0, 255, (96, 96, 3), uint8) - >>> env = GrayScaleObservation(gym.make("CarRacing-v2")) - >>> env.observation_space - Box(0, 255, (96, 96), uint8) - >>> env = GrayScaleObservation(gym.make("CarRacing-v2"), keep_dim=True) - >>> env.observation_space - Box(0, 255, (96, 96, 1), uint8) - """ - - def __init__(self, env: gym.Env, keep_dim: bool = False): - """Convert the image observation from RGB to gray scale. - - Args: - env (Env): The environment to apply the wrapper - keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1. - Otherwise, they are of shape AxB. - """ - gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim) - gym.ObservationWrapper.__init__(self, env) - - self.keep_dim = keep_dim - - assert ( - isinstance(self.observation_space, Box) - and len(self.observation_space.shape) == 3 - and self.observation_space.shape[-1] == 3 - ) - - obs_shape = self.observation_space.shape[:2] - if self.keep_dim: - self.observation_space = Box( - low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8 - ) - else: - self.observation_space = Box( - low=0, high=255, shape=obs_shape, dtype=np.uint8 - ) - - def observation(self, observation): - """Converts the colour observation to greyscale. - - Args: - observation: Color observations - - Returns: - Grayscale observations - """ - import cv2 - - observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY) - if self.keep_dim: - observation = np.expand_dims(observation, -1) - return observation diff --git a/gymnasium/wrappers/human_rendering.py b/gymnasium/wrappers/human_rendering.py deleted file mode 100644 index d7c23a106..000000000 --- a/gymnasium/wrappers/human_rendering.py +++ /dev/null @@ -1,142 +0,0 @@ -"""A wrapper that adds human-renering functionality to an environment.""" -import copy - -import numpy as np - -import gymnasium as gym -from gymnasium.error import DependencyNotInstalled - - -class HumanRendering(gym.Wrapper, gym.utils.RecordConstructorArgs): - """Performs human rendering for an environment that only supports "rgb_array"rendering. - - This wrapper is particularly useful when you have implemented an environment that can produce - RGB images but haven't implemented any code to render the images to the screen. - If you want to use this wrapper with your environments, remember to specify ``"render_fps"`` - in the metadata of your environment. - - The ``render_mode`` of the wrapped environment must be either ``'rgb_array'`` or ``'rgb_array_list'``. - - Example: - >>> import gymnasium as gym - >>> from gymnasium.wrappers import HumanRendering - >>> env = gym.make("LunarLander-v2", render_mode="rgb_array") - >>> wrapped = HumanRendering(env) - >>> obs, _ = wrapped.reset() # This will start rendering to the screen - - The wrapper can also be applied directly when the environment is instantiated, simply by passing - ``render_mode="human"`` to ``make``. The wrapper will only be applied if the environment does not - implement human-rendering natively (i.e. ``render_mode`` does not contain ``"human"``). - - >>> env = gym.make("phys2d/CartPole-v1", render_mode="human") # phys2d/CartPole-v1 doesn't implement human-rendering natively - >>> obs, _ = env.reset() # This will start rendering to the screen - - Warning: If the base environment uses ``render_mode="rgb_array_list"``, its (i.e. 
the *base environment's*) render method - will always return an empty list: - - >>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list") - >>> wrapped = HumanRendering(env) - >>> obs, _ = wrapped.reset() - >>> env.render() # env.render() will always return an empty list! - [] - """ - - def __init__(self, env): - """Initialize a :class:`HumanRendering` instance. - - Args: - env: The environment that is being wrapped - """ - gym.utils.RecordConstructorArgs.__init__(self) - gym.Wrapper.__init__(self, env) - - assert env.render_mode in [ - "rgb_array", - "rgb_array_list", - ], f"Expected env.render_mode to be one of 'rgb_array' or 'rgb_array_list' but got '{env.render_mode}'" - assert ( - "render_fps" in env.metadata - ), "The base environment must specify 'render_fps' to be used with the HumanRendering wrapper" - - self.screen_size = None - self.window = None - self.clock = None - - self.metadata = copy.deepcopy(self.env.metadata) - if "human" not in self.metadata["render_modes"]: - self.metadata["render_modes"].append("human") - - gym.utils.RecordConstructorArgs.__init__(self) - - @property - def render_mode(self): - """Always returns ``'human'``.""" - return "human" - - def step(self, *args, **kwargs): - """Perform a step in the base environment and render a frame to the screen.""" - result = self.env.step(*args, **kwargs) - self._render_frame() - return result - - def reset(self, *args, **kwargs): - """Reset the base environment and render a frame to the screen.""" - result = self.env.reset(*args, **kwargs) - self._render_frame() - return result - - def render(self): - """This method doesn't do much, actual rendering is performed in :meth:`step` and :meth:`reset`.""" - return None - - def _render_frame(self): - """Fetch the last frame from the base environment and render it to the screen.""" - try: - import pygame - except ImportError as e: - raise DependencyNotInstalled( - "pygame is not installed, run `pip install gymnasium[box2d]`" - ) from e - if self.env.render_mode == "rgb_array_list": - last_rgb_array = self.env.render() - assert isinstance(last_rgb_array, list) - last_rgb_array = last_rgb_array[-1] - elif self.env.render_mode == "rgb_array": - last_rgb_array = self.env.render() - else: - raise Exception( - f"Wrapped environment must have mode 'rgb_array' or 'rgb_array_list', actual render mode: {self.env.render_mode}" - ) - assert isinstance(last_rgb_array, np.ndarray) - - rgb_array = np.transpose(last_rgb_array, axes=(1, 0, 2)) - - if self.screen_size is None: - self.screen_size = rgb_array.shape[:2] - - assert ( - self.screen_size == rgb_array.shape[:2] - ), f"The shape of the rgb array has changed from {self.screen_size} to {rgb_array.shape[:2]}" - - if self.window is None: - pygame.init() - pygame.display.init() - self.window = pygame.display.set_mode(self.screen_size) - - if self.clock is None: - self.clock = pygame.time.Clock() - - surf = pygame.surfarray.make_surface(rgb_array) - self.window.blit(surf, (0, 0)) - pygame.event.pump() - self.clock.tick(self.metadata["render_fps"]) - pygame.display.flip() - - def close(self): - """Close the rendering window.""" - super().close() - if self.window is not None: - import pygame - - pygame.display.quit() - pygame.quit() diff --git a/gymnasium/experimental/wrappers/jax_to_numpy.py b/gymnasium/wrappers/jax_to_numpy.py similarity index 78% rename from gymnasium/experimental/wrappers/jax_to_numpy.py rename to gymnasium/wrappers/jax_to_numpy.py index 12db96631..807f19522 100644 --- 
a/gymnasium/experimental/wrappers/jax_to_numpy.py +++ b/gymnasium/wrappers/jax_to_numpy.py @@ -21,7 +21,7 @@ except ImportError: "Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install gymnasium[jax]`" ) -__all__ = ["JaxToNumpyV0", "jax_to_numpy", "numpy_to_jax"] +__all__ = ["JaxToNumpy", "jax_to_numpy", "numpy_to_jax"] @functools.singledispatch @@ -92,17 +92,39 @@ def _iterable_jax_to_numpy( return type(value)(jax_to_numpy(v) for v in value) -class JaxToNumpyV0( +class JaxToNumpy( gym.Wrapper[WrapperObsType, WrapperActType, ObsType, ActType], gym.utils.RecordConstructorArgs, ): - """Wraps a jax environment so that it can be interacted with through numpy arrays. + """Wraps a Jax-based environment such that it can be interacted with NumPy arrays. Actions must be provided as numpy arrays and observations will be returned as numpy arrays. + A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.JaxToNumpy`. Notes: The Jax To Numpy and Numpy to Jax conversion does not guarantee a roundtrip (jax -> numpy -> jax) and vice versa. The reason for this is jax does not support non-array values, therefore numpy ``int_32(5) -> DeviceArray([5], dtype=jnp.int32)`` + + Example: + >>> import gymnasium as gym # doctest: +SKIP + >>> env = gym.make("JaxEnv-vx") # doctest: +SKIP + >>> env = JaxToNumpy(env) # doctest: +SKIP + >>> obs, _ = env.reset(seed=123) # doctest: +SKIP + >>> type(obs) # doctest: +SKIP + <class 'numpy.ndarray'> + >>> action = env.action_space.sample() # doctest: +SKIP + >>> obs, reward, terminated, truncated, info = env.step(action) # doctest: +SKIP + >>> type(obs) # doctest: +SKIP + <class 'numpy.ndarray'> + >>> type(reward) # doctest: +SKIP + <class 'float'> + >>> type(terminated) # doctest: +SKIP + <class 'bool'> + >>> type(truncated) # doctest: +SKIP + <class 'bool'> + + Change logs: + * v1.0.0 - Initially added """ def __init__(self, env: gym.Env[ObsType, ActType]): diff --git a/gymnasium/experimental/wrappers/jax_to_torch.py b/gymnasium/wrappers/jax_to_torch.py similarity index 79% rename from gymnasium/experimental/wrappers/jax_to_torch.py rename to gymnasium/wrappers/jax_to_torch.py index 380c505bc..ce813fe7f 100644 --- a/gymnasium/experimental/wrappers/jax_to_torch.py +++ b/gymnasium/wrappers/jax_to_torch.py @@ -17,7 +17,7 @@ from typing import Any, Iterable, Mapping, SupportsFloat, Union import gymnasium as gym from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType from gymnasium.error import DependencyNotInstalled -from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy +from gymnasium.wrappers.jax_to_numpy import jax_to_numpy try: @@ -40,7 +40,7 @@ except ImportError: ) -__all__ = ["JaxToTorchV0", "jax_to_torch", "torch_to_jax", "Device"] +__all__ = ["JaxToTorch", "jax_to_torch", "torch_to_jax", "Device"] @functools.singledispatch @@ -114,13 +114,36 @@ def _jax_iterable_to_torch( return type(value)(jax_to_torch(v, device) for v in value) -class JaxToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs): - """Wraps a Jax-based environment so that it can be interacted with through PyTorch Tensors. +class JaxToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs): + """Wraps a Jax-based environment so that it can be interacted with PyTorch Tensors. Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors. + A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.JaxToTorch`. Note: For ``rendered`` this is returned as a NumPy array not a pytorch Tensor.
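The ``jax_to_numpy``/``jax_to_torch`` helpers above are built on ``functools.singledispatch``, registering one converter per container type and recursing through nested structures. A minimal standalone sketch of that dispatch pattern, with illustrative registrations only, not the library's exact implementation::

    import functools
    from typing import Any

    import numpy as np

    @functools.singledispatch
    def to_numpy(value: Any) -> Any:
        raise TypeError(f"No conversion registered for {type(value)}")

    @to_numpy.register(np.ndarray)
    def _(value: np.ndarray) -> np.ndarray:
        return np.asarray(value)

    @to_numpy.register(dict)
    def _(value: dict) -> dict:
        # Recurse so nested observation/info dicts convert element-wise.
        return {k: to_numpy(v) for k, v in value.items()}

    print(to_numpy({"obs": np.arange(3)}))  # {'obs': array([0, 1, 2])}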
+ + Example: + >>> import torch # doctest: +SKIP + >>> import gymnasium as gym # doctest: +SKIP + >>> env = gym.make("JaxEnv-vx") # doctest: +SKIP + >>> env = JaxToTorch(env) # doctest: +SKIP + >>> obs, _ = env.reset(seed=123) # doctest: +SKIP + >>> type(obs) # doctest: +SKIP + <class 'torch.Tensor'> + >>> action = torch.tensor(env.action_space.sample()) # doctest: +SKIP + >>> obs, reward, terminated, truncated, info = env.step(action) # doctest: +SKIP + >>> type(obs) # doctest: +SKIP + <class 'torch.Tensor'> + >>> type(reward) # doctest: +SKIP + <class 'float'> + >>> type(terminated) # doctest: +SKIP + <class 'bool'> + >>> type(truncated) # doctest: +SKIP + <class 'bool'> + + Change logs: + * v1.0.0 - Initially added """ def __init__(self, env: gym.Env, device: Device | None = None): diff --git a/gymnasium/wrappers/monitoring/__init__.py b/gymnasium/wrappers/monitoring/__init__.py deleted file mode 100644 index 589c74c7e..000000000 --- a/gymnasium/wrappers/monitoring/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Module for monitoring.video_recorder.""" diff --git a/gymnasium/wrappers/monitoring/video_recorder.py b/gymnasium/wrappers/monitoring/video_recorder.py deleted file mode 100644 index cd76cd104..000000000 --- a/gymnasium/wrappers/monitoring/video_recorder.py +++ /dev/null @@ -1,178 +0,0 @@ -"""A wrapper for video recording environments by rolling it out, frame by frame.""" -import json -import os -import os.path -import tempfile -from typing import List, Optional - -from gymnasium import error, logger - - -class VideoRecorder: - """VideoRecorder renders a nice movie of a rollout, frame by frame. - - It comes with an ``enabled`` option, so you can still use the same code on episodes where you don't want to record video. - - Note: - You are responsible for calling :meth:`close` on a created VideoRecorder, or else you may leak an encoder process. - """ - - def __init__( - self, - env, - path: Optional[str] = None, - metadata: Optional[dict] = None, - enabled: bool = True, - base_path: Optional[str] = None, - disable_logger: bool = False, - ): - """Video recorder renders a nice movie of a rollout, frame by frame. - - Args: - env (Env): Environment to take video of. - path (Optional[str]): Path to the video file; will be randomly chosen if omitted. - metadata (Optional[dict]): Contents to save to the metadata file. - enabled (bool): Whether to actually record video, or just no-op (for convenience) - base_path (Optional[str]): Alternatively, path to the video file without extension, which will be added. - disable_logger (bool): Whether to disable moviepy logger or not. - - Raises: - Error: You can pass at most one of `path` or `base_path` - Error: Invalid path given that must have a particular file extension - """ - self._async = env.metadata.get("semantics.async") - self.enabled = enabled - self.disable_logger = disable_logger - self._closed = False - - self.render_history = [] - self.env = env - - self.render_mode = env.render_mode - - try: - # check that moviepy is now installed - import moviepy # noqa: F401 - except ImportError as e: - raise error.DependencyNotInstalled( - "moviepy is not installed, run `pip install moviepy`" - ) from e - - if self.render_mode in {None, "human", "ansi", "ansi_list"}: - raise ValueError( - f"Render mode is {self.render_mode}, which is incompatible with" - f" RecordVideo. Initialize your environment with a render_mode" - f" that returns an image, such as rgb_array."
- ) - - # Don't bother setting anything else if not enabled - if not self.enabled: - return - - if path is not None and base_path is not None: - raise error.Error("You can pass at most one of `path` or `base_path`.") - - required_ext = ".mp4" - if path is None: - if base_path is not None: - # Base path given, append ext - path = base_path + required_ext - else: - # Otherwise, just generate a unique filename - with tempfile.NamedTemporaryFile(suffix=required_ext) as f: - path = f.name - self.path = path - - path_base, actual_ext = os.path.splitext(self.path) - - if actual_ext != required_ext: - raise error.Error( - f"Invalid path given: {self.path} -- must have file extension {required_ext}." - ) - - self.frames_per_sec = env.metadata.get("render_fps", 30) - - self.broken = False - - # Dump metadata - self.metadata = metadata or {} - self.metadata["content_type"] = "video/mp4" - self.metadata_path = f"{path_base}.meta.json" - self.write_metadata() - - logger.info(f"Starting new video recorder writing to {self.path}") - self.recorded_frames = [] - - @property - def functional(self): - """Returns if the video recorder is functional, is enabled and not broken.""" - return self.enabled and not self.broken - - def capture_frame(self): - """Render the given `env` and add the resulting frame to the video.""" - frame = self.env.render() - if isinstance(frame, List): - self.render_history += frame - frame = frame[-1] - - if not self.functional: - return - if self._closed: - logger.warn( - "The video recorder has been closed and no frames will be captured anymore." - ) - return - logger.debug("Capturing video frame: path=%s", self.path) - - if frame is None: - if self._async: - return - else: - # Indicates a bug in the environment: don't want to raise - # an error here. - logger.warn( - "Env returned None on `render()`. Disabling further rendering for video recorder by marking as " - f"disabled: path={self.path} metadata_path={self.metadata_path}" - ) - self.broken = True - else: - self.recorded_frames.append(frame) - - def close(self): - """Flush all data to disk and close any open frame encoders.""" - if not self.enabled or self._closed: - return - - # Close the encoder - if len(self.recorded_frames) > 0: - try: - from moviepy.video.io.ImageSequenceClip import ImageSequenceClip - except ImportError as e: - raise error.DependencyNotInstalled( - "moviepy is not installed, run `pip install moviepy`" - ) from e - - clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec) - moviepy_logger = None if self.disable_logger else "bar" - clip.write_videofile(self.path, logger=moviepy_logger) - else: - # No frames captured. Set metadata. - if self.metadata is None: - self.metadata = {} - self.metadata["empty"] = True - - self.write_metadata() - - # Stop tracking this for autoclose - self._closed = True - - def write_metadata(self): - """Writes metadata to metadata path.""" - with open(self.metadata_path, "w") as f: - json.dump(self.metadata, f) - - def __del__(self): - """Closes the environment correctly when the recorder is deleted.""" - # Make sure we've closed up shop when garbage collecting - if not self._closed: - logger.warn("Unable to save last video! 
Did you call close()?") diff --git a/gymnasium/wrappers/normalize.py b/gymnasium/wrappers/normalize.py deleted file mode 100644 index 8e9632c2f..000000000 --- a/gymnasium/wrappers/normalize.py +++ /dev/null @@ -1,155 +0,0 @@ -"""Set of wrappers for normalizing actions and observations.""" -import numpy as np - -import gymnasium as gym - - -class RunningMeanStd: - """Tracks the mean, variance and count of values.""" - - # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm - def __init__(self, epsilon=1e-4, shape=()): - """Tracks the mean, variance and count of values.""" - self.mean = np.zeros(shape, "float64") - self.var = np.ones(shape, "float64") - self.count = epsilon - - def update(self, x): - """Updates the mean, var and count from a batch of samples.""" - batch_mean = np.mean(x, axis=0) - batch_var = np.var(x, axis=0) - batch_count = x.shape[0] - self.update_from_moments(batch_mean, batch_var, batch_count) - - def update_from_moments(self, batch_mean, batch_var, batch_count): - """Updates from batch mean, variance and count moments.""" - self.mean, self.var, self.count = update_mean_var_count_from_moments( - self.mean, self.var, self.count, batch_mean, batch_var, batch_count - ) - - -def update_mean_var_count_from_moments( - mean, var, count, batch_mean, batch_var, batch_count -): - """Updates the mean, var and count using the previous mean, var, count and batch values.""" - delta = batch_mean - mean - tot_count = count + batch_count - - new_mean = mean + delta * batch_count / tot_count - m_a = var * count - m_b = batch_var * batch_count - M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count - new_var = M2 / tot_count - new_count = tot_count - - return new_mean, new_var, new_count - - -class NormalizeObservation(gym.Wrapper, gym.utils.RecordConstructorArgs): - """This wrapper will normalize observations s.t. each coordinate is centered with unit variance. - - Note: - The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was - newly instantiated or the policy was changed recently. - """ - - def __init__(self, env: gym.Env, epsilon: float = 1e-8): - """This wrapper will normalize observations s.t. each coordinate is centered with unit variance. - - Args: - env (Env): The environment to apply the wrapper - epsilon: A stability parameter that is used when scaling the observations. 
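The ``update_mean_var_count_from_moments`` function above implements the parallel mean/variance update (Chan et al.'s parallel algorithm linked in the source). A quick standalone numerical check against batched NumPy moments, with an illustrative epsilon-sized initial count::

    import numpy as np

    def update(mean, var, count, batch):
        # Same algebra as update_mean_var_count_from_moments above.
        batch_mean, batch_var, batch_count = batch.mean(0), batch.var(0), batch.shape[0]
        delta = batch_mean - mean
        tot = count + batch_count
        new_mean = mean + delta * batch_count / tot
        m2 = var * count + batch_var * batch_count + delta**2 * count * batch_count / tot
        return new_mean, m2 / tot, tot

    rng = np.random.default_rng(0)
    a, b = rng.normal(size=(64, 3)), rng.normal(size=(32, 3))
    mean, var, count = update(*update(np.zeros(3), np.ones(3), 1e-4, a), b)

    # Up to the epsilon-sized initial count, the running moments match
    # the moments of the concatenated batches.
    assert np.allclose(mean, np.concatenate([a, b]).mean(0), atol=1e-3)
    assert np.allclose(var, np.concatenate([a, b]).var(0), atol=1e-3)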
- """ - gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon) - gym.Wrapper.__init__(self, env) - - try: - self.num_envs = self.get_wrapper_attr("num_envs") - self.is_vector_env = self.get_wrapper_attr("is_vector_env") - except AttributeError: - self.num_envs = 1 - self.is_vector_env = False - - if self.is_vector_env: - self.obs_rms = RunningMeanStd(shape=self.single_observation_space.shape) - else: - self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) - self.epsilon = epsilon - - def step(self, action): - """Steps through the environment and normalizes the observation.""" - obs, rews, terminateds, truncateds, infos = self.env.step(action) - if self.is_vector_env: - obs = self.normalize(obs) - else: - obs = self.normalize(np.array([obs]))[0] - return obs, rews, terminateds, truncateds, infos - - def reset(self, **kwargs): - """Resets the environment and normalizes the observation.""" - obs, info = self.env.reset(**kwargs) - - if self.is_vector_env: - return self.normalize(obs), info - else: - return self.normalize(np.array([obs]))[0], info - - def normalize(self, obs): - """Normalises the observation using the running mean and variance of the observations.""" - self.obs_rms.update(obs) - return (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon) - - -class NormalizeReward(gym.core.Wrapper, gym.utils.RecordConstructorArgs): - r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. - - The exponential moving average will have variance :math:`(1 - \gamma)^2`. - - Note: - The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly - instantiated or the policy was changed recently. - """ - - def __init__( - self, - env: gym.Env, - gamma: float = 0.99, - epsilon: float = 1e-8, - ): - """This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. - - Args: - env (env): The environment to apply the wrapper - epsilon (float): A stability parameter - gamma (float): The discount factor that is used in the exponential moving average. 
- """ - gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon) - gym.Wrapper.__init__(self, env) - - try: - self.num_envs = self.get_wrapper_attr("num_envs") - self.is_vector_env = self.get_wrapper_attr("is_vector_env") - except AttributeError: - self.num_envs = 1 - self.is_vector_env = False - - self.return_rms = RunningMeanStd(shape=()) - self.returns = np.zeros(self.num_envs) - self.gamma = gamma - self.epsilon = epsilon - - def step(self, action): - """Steps through the environment, normalizing the rewards returned.""" - obs, rews, terminateds, truncateds, infos = self.env.step(action) - if not self.is_vector_env: - rews = np.array([rews]) - self.returns = self.returns * self.gamma * (1 - terminateds) + rews - rews = self.normalize(rews) - if not self.is_vector_env: - rews = rews[0] - return obs, rews, terminateds, truncateds, infos - - def normalize(self, rews): - """Normalizes the rewards with the running mean rewards and their variance.""" - self.return_rms.update(self.returns) - return rews / np.sqrt(self.return_rms.var + self.epsilon) diff --git a/gymnasium/experimental/wrappers/numpy_to_torch.py b/gymnasium/wrappers/numpy_to_torch.py similarity index 82% rename from gymnasium/experimental/wrappers/numpy_to_torch.py rename to gymnasium/wrappers/numpy_to_torch.py index 181eb9120..d3d00cc9e 100644 --- a/gymnasium/experimental/wrappers/numpy_to_torch.py +++ b/gymnasium/wrappers/numpy_to_torch.py @@ -23,7 +23,7 @@ except ImportError: ) -__all__ = ["NumpyToTorchV0", "torch_to_numpy", "numpy_to_torch"] +__all__ = ["NumpyToTorch", "torch_to_numpy", "numpy_to_torch"] @functools.singledispatch @@ -61,6 +61,7 @@ def numpy_to_torch(value: Any, device: Device | None = None) -> Any: ) +@numpy_to_torch.register(numbers.Number) @numpy_to_torch.register(np.ndarray) def _numpy_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Tensor: """Converts a Jax Array into a PyTorch Tensor.""" @@ -84,16 +85,39 @@ def _numpy_iterable_to_torch( value: Iterable[Any], device: Device | None = None ) -> Iterable[Any]: """Converts an Iterable from Jax Array to an iterable of PyTorch Tensors.""" - return type(value)(numpy_to_torch(v, device) for v in value) + return type(value)(tuple(numpy_to_torch(v, device) for v in value)) -class NumpyToTorchV0(gym.Wrapper, gym.utils.RecordConstructorArgs): - """Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors. +class NumpyToTorch(gym.Wrapper, gym.utils.RecordConstructorArgs): + """Wraps a NumPy-based environment such that it can be interacted with PyTorch Tensors. Actions must be provided as PyTorch Tensors and observations will be returned as PyTorch Tensors. + A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.NumpyToTorch`. Note: For ``rendered`` this is returned as a NumPy array not a pytorch Tensor. 
+ + Example: + >>> import torch + >>> import gymnasium as gym + >>> env = gym.make("CartPole-v1") + >>> env = NumpyToTorch(env) + >>> obs, _ = env.reset(seed=123) + >>> type(obs) + <class 'torch.Tensor'> + >>> action = torch.tensor(env.action_space.sample()) + >>> obs, reward, terminated, truncated, info = env.step(action) + >>> type(obs) + <class 'torch.Tensor'> + >>> type(reward) + <class 'float'> + >>> type(terminated) + <class 'bool'> + >>> type(truncated) + <class 'bool'> + + Change logs: + * v1.0.0 - Initially added """ def __init__(self, env: gym.Env, device: Device | None = None): diff --git a/gymnasium/wrappers/order_enforcing.py b/gymnasium/wrappers/order_enforcing.py deleted file mode 100644 index abfb1be0f..000000000 --- a/gymnasium/wrappers/order_enforcing.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Wrapper to enforce the proper ordering of environment operations.""" -from __future__ import annotations - -from copy import deepcopy -from typing import TYPE_CHECKING - -import gymnasium as gym -from gymnasium.error import ResetNeeded - - -if TYPE_CHECKING: - from gymnasium.envs.registration import EnvSpec - - -class OrderEnforcing(gym.Wrapper, gym.utils.RecordConstructorArgs): - """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`. - - Example: - >>> import gymnasium as gym - >>> from gymnasium.wrappers import OrderEnforcing - >>> env = gym.make("CartPole-v1", render_mode="human") - >>> env = OrderEnforcing(env) - >>> env.step(0) - Traceback (most recent call last): - ... - gymnasium.error.ResetNeeded: Cannot call env.step() before calling env.reset() - >>> env.render() - Traceback (most recent call last): - ... - gymnasium.error.ResetNeeded: Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper. - >>> _ = env.reset() - >>> env.render() - >>> _ = env.step(0) - >>> env.close() - """ - - def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False): - """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`. - - Args: - env: The environment to wrap - disable_render_order_enforcing: If to disable render order enforcing - """ - gym.utils.RecordConstructorArgs.__init__( - self, disable_render_order_enforcing=disable_render_order_enforcing - ) - gym.Wrapper.__init__(self, env) - - self._has_reset: bool = False - self._disable_render_order_enforcing: bool = disable_render_order_enforcing - - def step(self, action): - """Steps through the environment with `kwargs`.""" - if not self._has_reset: - raise ResetNeeded("Cannot call env.step() before calling env.reset()") - return self.env.step(action) - - def reset(self, **kwargs): - """Resets the environment with `kwargs`.""" - self._has_reset = True - return self.env.reset(**kwargs) - - def render(self, *args, **kwargs): - """Renders the environment with `kwargs`.""" - if not self._disable_render_order_enforcing and not self._has_reset: - raise ResetNeeded( - "Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, " - "set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
- ) - return self.env.render(*args, **kwargs) - - @property - def has_reset(self): - """Returns if the environment has been reset before.""" - return self._has_reset - - @property - def spec(self) -> EnvSpec | None: - """Modifies the environment spec to add the `order_enforce=True`.""" - if self._cached_spec is not None: - return self._cached_spec - - env_spec = self.env.spec - if env_spec is not None: - env_spec = deepcopy(env_spec) - env_spec.order_enforce = True - - self._cached_spec = env_spec - return env_spec diff --git a/gymnasium/wrappers/pixel_observation.py b/gymnasium/wrappers/pixel_observation.py deleted file mode 100644 index 728af6eb7..000000000 --- a/gymnasium/wrappers/pixel_observation.py +++ /dev/null @@ -1,215 +0,0 @@ -"""Wrapper for augmenting observations by pixel values.""" -import collections -import copy -from collections.abc import MutableMapping -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np - -import gymnasium as gym -from gymnasium import spaces - - -STATE_KEY = "state" - - -class PixelObservationWrapper(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): - """Augment observations by pixel values. - - Observations of this wrapper will be dictionaries of images. - You can also choose to add the observation of the base environment to this dictionary. - In that case, if the base environment has an observation space of type :class:`Dict`, the dictionary - of rendered images will be updated with the base environment's observation. If, however, the observation - space is of type :class:`Box`, the base environment's observation (which will be an element of the :class:`Box` - space) will be added to the dictionary under the key "state". - - Example: - >>> import gymnasium as gym - >>> from gymnasium.wrappers import PixelObservationWrapper - >>> env = PixelObservationWrapper(gym.make("CarRacing-v2", render_mode="rgb_array")) - >>> obs, _ = env.reset() - >>> obs.keys() - odict_keys(['pixels']) - >>> obs['pixels'].shape - (400, 600, 3) - >>> env = PixelObservationWrapper(gym.make("CarRacing-v2", render_mode="rgb_array"), pixels_only=False) - >>> obs, _ = env.reset() - >>> obs.keys() - odict_keys(['state', 'pixels']) - >>> obs['state'].shape - (96, 96, 3) - >>> obs['pixels'].shape - (400, 600, 3) - >>> env = PixelObservationWrapper(gym.make("CarRacing-v2", render_mode="rgb_array"), pixel_keys=('obs',)) - >>> obs, _ = env.reset() - >>> obs.keys() - odict_keys(['obs']) - >>> obs['obs'].shape - (400, 600, 3) - """ - - def __init__( - self, - env: gym.Env, - pixels_only: bool = True, - render_kwargs: Optional[Dict[str, Dict[str, Any]]] = None, - pixel_keys: Tuple[str, ...] = ("pixels",), - ): - """Initializes a new pixel Wrapper. - - Args: - env: The environment to wrap. - pixels_only (bool): If `True` (default), the original observation returned - by the wrapped environment will be discarded, and a dictionary - observation will only include pixels. If `False`, the - observation dictionary will contain both the original - observations and the pixel observations. - render_kwargs (dict): Optional dictionary containing that maps elements of `pixel_keys` to - keyword arguments passed to the :meth:`self.render` method. - pixel_keys: Optional custom string specifying the pixel - observation's key in the `OrderedDict` of observations. - Defaults to `(pixels,)`. - - Raises: - AssertionError: If any of the keys in ``render_kwargs``do not show up in ``pixel_keys``. - ValueError: If ``env``'s observation space is not compatible with the - wrapper. 
Supported formats are a single array, or a dict of - arrays. - ValueError: If ``env``'s observation already contains any of the - specified ``pixel_keys``. - TypeError: When an unexpected pixel type is used - """ - gym.utils.RecordConstructorArgs.__init__( - self, - pixels_only=pixels_only, - render_kwargs=render_kwargs, - pixel_keys=pixel_keys, - ) - gym.ObservationWrapper.__init__(self, env) - - # Avoid side-effects that occur when render_kwargs is manipulated - render_kwargs = copy.deepcopy(render_kwargs) - self.render_history = [] - - if render_kwargs is None: - render_kwargs = {} - - for key in render_kwargs: - assert key in pixel_keys, ( - "The argument render_kwargs should map elements of " - "pixel_keys to dictionaries of keyword arguments. " - f"Found key '{key}' in render_kwargs but not in pixel_keys." - ) - - default_render_kwargs = {} - if not env.render_mode: - raise AttributeError( - "env.render_mode must be specified to use PixelObservationWrapper:" - "`gymnasium.make(env_name, render_mode='rgb_array')`." - ) - - for key in pixel_keys: - render_kwargs.setdefault(key, default_render_kwargs) - - wrapped_observation_space = env.observation_space - - if isinstance(wrapped_observation_space, spaces.Box): - self._observation_is_dict = False - invalid_keys = {STATE_KEY} - elif isinstance(wrapped_observation_space, (spaces.Dict, MutableMapping)): - self._observation_is_dict = True - invalid_keys = set(wrapped_observation_space.spaces.keys()) - else: - raise ValueError("Unsupported observation space structure.") - - if not pixels_only: - # Make sure that now keys in the `pixel_keys` overlap with - # `observation_keys` - overlapping_keys = set(pixel_keys) & set(invalid_keys) - if overlapping_keys: - raise ValueError( - f"Duplicate or reserved pixel keys {overlapping_keys!r}." - ) - - if pixels_only: - self.observation_space = spaces.Dict() - elif self._observation_is_dict: - self.observation_space = copy.deepcopy(wrapped_observation_space) - else: - self.observation_space = spaces.Dict({STATE_KEY: wrapped_observation_space}) - - # Extend observation space with pixels. - - self.env.reset() - pixels_spaces = {} - for pixel_key in pixel_keys: - pixels = self._render(**render_kwargs[pixel_key]) - pixels: np.ndarray = pixels[-1] if isinstance(pixels, List) else pixels - - if not hasattr(pixels, "dtype") or not hasattr(pixels, "shape"): - raise TypeError( - f"Render method returns a {pixels.__class__.__name__}, but an array with dtype and shape is expected." - "Be sure to specify the correct render_mode." - ) - - if np.issubdtype(pixels.dtype, np.integer): - low, high = (0, 255) - elif np.issubdtype(pixels.dtype, np.float): - low, high = (-float("inf"), float("inf")) - else: - raise TypeError(pixels.dtype) - - pixels_space = spaces.Box( - shape=pixels.shape, low=low, high=high, dtype=pixels.dtype - ) - pixels_spaces[pixel_key] = pixels_space - - self.observation_space.spaces.update(pixels_spaces) - - self._pixels_only = pixels_only - self._render_kwargs = render_kwargs - self._pixel_keys = pixel_keys - - def observation(self, observation): - """Updates the observations with the pixel observations. 
- - Args: - observation: The observation to add pixel observations for - - Returns: - The updated pixel observations - """ - pixel_observation = self._add_pixel_observation(observation) - return pixel_observation - - def _add_pixel_observation(self, wrapped_observation): - if self._pixels_only: - observation = collections.OrderedDict() - elif self._observation_is_dict: - observation = type(wrapped_observation)(wrapped_observation) - else: - observation = collections.OrderedDict() - observation[STATE_KEY] = wrapped_observation - - pixel_observations = { - pixel_key: self._render(**self._render_kwargs[pixel_key]) - for pixel_key in self._pixel_keys - } - - observation.update(pixel_observations) - - return observation - - def render(self, *args, **kwargs): - """Renders the environment.""" - render = self.env.render(*args, **kwargs) - if isinstance(render, list): - render = self.render_history + render - self.render_history = [] - return render - - def _render(self, *args, **kwargs): - render = self.env.render(*args, **kwargs) - if isinstance(render, list): - self.render_history += render - return render diff --git a/gymnasium/wrappers/record_episode_statistics.py b/gymnasium/wrappers/record_episode_statistics.py deleted file mode 100644 index 1abe4c78b..000000000 --- a/gymnasium/wrappers/record_episode_statistics.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Wrapper that tracks the cumulative rewards and episode lengths.""" -import time -from collections import deque -from typing import Optional - -import numpy as np - -import gymnasium as gym - - -class RecordEpisodeStatistics(gym.Wrapper, gym.utils.RecordConstructorArgs): - """This wrapper will keep track of cumulative rewards and episode lengths. - - At the end of an episode, the statistics of the episode will be added to ``info`` - using the key ``episode``. If using a vectorized environment also the key - ``_episode`` is used which indicates whether the env at the respective index has - the episode statistics. - - After the completion of an episode, ``info`` will look like this:: - - >>> info = { - ... "episode": { - ... "r": "", - ... "l": "", - ... "t": "" - ... }, - ... } - - For a vectorized environments the output will be in the form of:: - - >>> infos = { - ... "final_observation": "", - ... "_final_observation": "", - ... "final_info": "", - ... "_final_info": "", - ... "episode": { - ... "r": "", - ... "l": "", - ... "t": "" - ... }, - ... "_episode": "" - ... } - - Moreover, the most recent rewards and episode lengths are stored in buffers that can be accessed via - :attr:`wrapped_env.return_queue` and :attr:`wrapped_env.length_queue` respectively. - - Attributes: - return_queue: The cumulative rewards of the last ``deque_size``-many episodes - length_queue: The lengths of the last ``deque_size``-many episodes - """ - - def __init__(self, env: gym.Env, deque_size: int = 100): - """This wrapper will keep track of cumulative rewards and episode lengths. 
- - Args: - env (Env): The environment to apply the wrapper - deque_size: The size of the buffers :attr:`return_queue` and :attr:`length_queue` - """ - gym.utils.RecordConstructorArgs.__init__(self, deque_size=deque_size) - gym.Wrapper.__init__(self, env) - - try: - self.num_envs = self.get_wrapper_attr("num_envs") - self.is_vector_env = self.get_wrapper_attr("is_vector_env") - except AttributeError: - self.num_envs = 1 - self.is_vector_env = False - - self.episode_count = 0 - self.episode_start_times: np.ndarray = None - self.episode_returns: Optional[np.ndarray] = None - self.episode_lengths: Optional[np.ndarray] = None - self.return_queue = deque(maxlen=deque_size) - self.length_queue = deque(maxlen=deque_size) - - def reset(self, **kwargs): - """Resets the environment using kwargs and resets the episode returns and lengths.""" - obs, info = super().reset(**kwargs) - self.episode_start_times = np.full( - self.num_envs, time.perf_counter(), dtype=np.float32 - ) - self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) - self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) - return obs, info - - def step(self, action): - """Steps through the environment, recording the episode statistics.""" - ( - observations, - rewards, - terminations, - truncations, - infos, - ) = self.env.step(action) - assert isinstance( - infos, dict - ), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order." - self.episode_returns += rewards - self.episode_lengths += 1 - dones = np.logical_or(terminations, truncations) - num_dones = np.sum(dones) - if num_dones: - if "episode" in infos or "_episode" in infos: - raise ValueError( - "Attempted to add episode stats when they already exist" - ) - else: - infos["episode"] = { - "r": np.where(dones, self.episode_returns, 0.0), - "l": np.where(dones, self.episode_lengths, 0), - "t": np.where( - dones, - np.round(time.perf_counter() - self.episode_start_times, 6), - 0.0, - ), - } - if self.is_vector_env: - infos["_episode"] = np.where(dones, True, False) - self.return_queue.extend(self.episode_returns[dones]) - self.length_queue.extend(self.episode_lengths[dones]) - self.episode_count += num_dones - self.episode_lengths[dones] = 0 - self.episode_returns[dones] = 0 - self.episode_start_times[dones] = time.perf_counter() - return ( - observations, - rewards, - terminations, - truncations, - infos, - ) diff --git a/gymnasium/wrappers/record_video.py b/gymnasium/wrappers/record_video.py deleted file mode 100644 index 7c43215e6..000000000 --- a/gymnasium/wrappers/record_video.py +++ /dev/null @@ -1,228 +0,0 @@ -"""Wrapper for recording videos.""" -import os -from typing import Callable, Optional - -import gymnasium as gym -from gymnasium import logger -from gymnasium.wrappers.monitoring import video_recorder - - -def capped_cubic_video_schedule(episode_id: int) -> bool: - """The default episode trigger. - - This function will trigger recordings at the episode indices 0, 1, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ... - - Args: - episode_id: The episode number - - Returns: - If to apply a video schedule number - """ - if episode_id < 1000: - return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id - else: - return episode_id % 1000 == 0 - - -class RecordVideo(gym.Wrapper, gym.utils.RecordConstructorArgs): - """This wrapper records videos of rollouts. - - Usually, you only want to record episodes intermittently, say every hundredth episode. 
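The default ``capped_cubic_video_schedule`` defined above fires at cube-numbered episodes below 1000 and at every 1000th episode afterwards; a quick standalone check of the episodes it selects::

    def capped_cubic_video_schedule(episode_id: int) -> bool:
        # Same rule as the schedule in the deleted record_video.py above.
        if episode_id < 1000:
            return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
        return episode_id % 1000 == 0

    print([i for i in range(2001) if capped_cubic_video_schedule(i)])
    # [0, 1, 8, 27, 64, 125, 216, 343, 512, 729, 1000, 2000]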
- To do this, you can specify **either** ``episode_trigger`` **or** ``step_trigger`` (not both). - They should be functions returning a boolean that indicates whether a recording should be started at the - current episode or step, respectively. - If neither :attr:`episode_trigger` nor ``step_trigger`` is passed, a default ``episode_trigger`` will be employed. - By default, the recording will be stopped once a `terminated` or `truncated` signal has been emitted by the environment. However, you can - also create recordings of fixed length (possibly spanning several episodes) by passing a strictly positive value for - ``video_length``. - """ - - def __init__( - self, - env: gym.Env, - video_folder: str, - episode_trigger: Callable[[int], bool] = None, - step_trigger: Callable[[int], bool] = None, - video_length: int = 0, - name_prefix: str = "rl-video", - disable_logger: bool = False, - ): - """Wrapper records videos of rollouts. - - Args: - env: The environment that will be wrapped - video_folder (str): The folder where the recordings will be stored - episode_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this episode - step_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this step - video_length (int): The length of recorded episodes. If 0, entire episodes are recorded. - Otherwise, snippets of the specified length are captured - name_prefix (str): Will be prepended to the filename of the recordings - disable_logger (bool): Whether to disable moviepy logger or not. - """ - gym.utils.RecordConstructorArgs.__init__( - self, - video_folder=video_folder, - episode_trigger=episode_trigger, - step_trigger=step_trigger, - video_length=video_length, - name_prefix=name_prefix, - disable_logger=disable_logger, - ) - gym.Wrapper.__init__(self, env) - - if env.render_mode in {None, "human", "ansi", "ansi_list"}: - raise ValueError( - f"Render mode is {env.render_mode}, which is incompatible with" - f" RecordVideo. Initialize your environment with a render_mode" - f" that returns an image, such as rgb_array." 
- ) - - if episode_trigger is None and step_trigger is None: - episode_trigger = capped_cubic_video_schedule - - trigger_count = sum(x is not None for x in [episode_trigger, step_trigger]) - assert trigger_count == 1, "Must specify exactly one trigger" - - self.episode_trigger = episode_trigger - self.step_trigger = step_trigger - self.video_recorder: Optional[video_recorder.VideoRecorder] = None - self.disable_logger = disable_logger - - self.video_folder = os.path.abspath(video_folder) - # Create output folder if needed - if os.path.isdir(self.video_folder): - logger.warn( - f"Overwriting existing videos at {self.video_folder} folder " - f"(try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)" - ) - os.makedirs(self.video_folder, exist_ok=True) - - self.name_prefix = name_prefix - self.step_id = 0 - self.video_length = video_length - - self.recording = False - self.terminated = False - self.truncated = False - self.recorded_frames = 0 - self.episode_id = 0 - - try: - self.is_vector_env = self.get_wrapper_attr("is_vector_env") - except AttributeError: - self.is_vector_env = False - - def reset(self, **kwargs): - """Reset the environment using kwargs and then starts recording if video enabled.""" - observations = super().reset(**kwargs) - self.terminated = False - self.truncated = False - if self.recording: - assert self.video_recorder is not None - self.video_recorder.recorded_frames = [] - self.video_recorder.capture_frame() - self.recorded_frames += 1 - if self.video_length > 0: - if self.recorded_frames > self.video_length: - self.close_video_recorder() - elif self._video_enabled(): - self.start_video_recorder() - return observations - - def start_video_recorder(self): - """Starts video recorder using :class:`video_recorder.VideoRecorder`.""" - self.close_video_recorder() - - video_name = f"{self.name_prefix}-step-{self.step_id}" - if self.episode_trigger: - video_name = f"{self.name_prefix}-episode-{self.episode_id}" - - base_path = os.path.join(self.video_folder, video_name) - self.video_recorder = video_recorder.VideoRecorder( - env=self.env, - base_path=base_path, - metadata={"step_id": self.step_id, "episode_id": self.episode_id}, - disable_logger=self.disable_logger, - ) - - self.video_recorder.capture_frame() - self.recorded_frames = 1 - self.recording = True - - def _video_enabled(self): - if self.step_trigger: - return self.step_trigger(self.step_id) - else: - return self.episode_trigger(self.episode_id) - - def step(self, action): - """Steps through the environment using action, recording observations if :attr:`self.recording`.""" - ( - observations, - rewards, - terminateds, - truncateds, - infos, - ) = self.env.step(action) - - if not (self.terminated or self.truncated): - # increment steps and episodes - self.step_id += 1 - if not self.is_vector_env: - if terminateds or truncateds: - self.episode_id += 1 - self.terminated = terminateds - self.truncated = truncateds - elif terminateds[0] or truncateds[0]: - self.episode_id += 1 - self.terminated = terminateds[0] - self.truncated = truncateds[0] - - if self.recording: - assert self.video_recorder is not None - self.video_recorder.capture_frame() - self.recorded_frames += 1 - if self.video_length > 0: - if self.recorded_frames > self.video_length: - self.close_video_recorder() - else: - if not self.is_vector_env: - if terminateds or truncateds: - self.close_video_recorder() - elif terminateds[0] or truncateds[0]: - self.close_video_recorder() - - elif self._video_enabled(): - 
self.start_video_recorder() - - return observations, rewards, terminateds, truncateds, infos - - def close_video_recorder(self): - """Closes the video recorder if currently recording.""" - if self.recording: - assert self.video_recorder is not None - self.video_recorder.close() - self.recording = False - self.recorded_frames = 1 - - def render(self, *args, **kwargs): - """Compute the render frames as specified by render_mode attribute during initialization of the environment or as specified in kwargs.""" - if self.video_recorder is None or not self.video_recorder.enabled: - return super().render(*args, **kwargs) - - if len(self.video_recorder.render_history) > 0: - recorded_frames = [ - self.video_recorder.render_history.pop() - for _ in range(len(self.video_recorder.render_history)) - ] - if self.recording: - return recorded_frames - else: - return recorded_frames + super().render(*args, **kwargs) - else: - return super().render(*args, **kwargs) - - def close(self): - """Closes the wrapper then the video recorder.""" - super().close() - self.close_video_recorder() diff --git a/gymnasium/wrappers/render_collection.py b/gymnasium/wrappers/render_collection.py deleted file mode 100644 index 254006e2a..000000000 --- a/gymnasium/wrappers/render_collection.py +++ /dev/null @@ -1,62 +0,0 @@ -"""A wrapper that adds render collection mode to an environment.""" -import copy - -import gymnasium as gym - - -class RenderCollection(gym.Wrapper, gym.utils.RecordConstructorArgs): - """Save collection of render frames.""" - - def __init__(self, env: gym.Env, pop_frames: bool = True, reset_clean: bool = True): - """Initialize a :class:`RenderCollection` instance. - - Args: - env: The environment that is being wrapped - pop_frames (bool): If true, clear the collection frames after .render() is called. - Default value is True. - reset_clean (bool): If true, clear the collection frames when .reset() is called. - Default value is True. 
- """ - gym.utils.RecordConstructorArgs.__init__( - self, pop_frames=pop_frames, reset_clean=reset_clean - ) - gym.Wrapper.__init__(self, env) - - assert env.render_mode is not None - assert not env.render_mode.endswith("_list") - self.frame_list = [] - self.reset_clean = reset_clean - self.pop_frames = pop_frames - - self.metadata = copy.deepcopy(self.env.metadata) - if f"{self.env.render_mode}_list" not in self.metadata["render_modes"]: - self.metadata["render_modes"].append(f"{self.env.render_mode}_list") - - @property - def render_mode(self): - """Returns the collection render_mode name.""" - return f"{self.env.render_mode}_list" - - def step(self, *args, **kwargs): - """Perform a step in the base environment and collect a frame.""" - output = self.env.step(*args, **kwargs) - self.frame_list.append(self.env.render()) - return output - - def reset(self, *args, **kwargs): - """Reset the base environment, eventually clear the frame_list, and collect a frame.""" - result = self.env.reset(*args, **kwargs) - - if self.reset_clean: - self.frame_list = [] - self.frame_list.append(self.env.render()) - - return result - - def render(self): - """Returns the collection of frames and, if pop_frames = True, clears it.""" - frames = self.frame_list - if self.pop_frames: - self.frame_list = [] - - return frames diff --git a/gymnasium/experimental/wrappers/rendering.py b/gymnasium/wrappers/rendering.py similarity index 72% rename from gymnasium/experimental/wrappers/rendering.py rename to gymnasium/wrappers/rendering.py index b502945b4..6fc1e451d 100644 --- a/gymnasium/experimental/wrappers/rendering.py +++ b/gymnasium/wrappers/rendering.py @@ -1,8 +1,8 @@ """A collections of rendering-based wrappers. -* ``RenderCollectionV0`` - Collects rendered frames into a list -* ``RecordVideoV0`` - Records a video of the environments -* ``HumanRenderingV0`` - Provides human rendering of environments with ``"rgb_array"`` +* ``RenderCollection`` - Collects rendered frames into a list +* ``RecordVideo`` - Records a video of the environments +* ``HumanRendering`` - Provides human rendering of environments with ``"rgb_array"`` """ from __future__ import annotations @@ -18,13 +18,76 @@ from gymnasium.core import ActType, ObsType, RenderFrame from gymnasium.error import DependencyNotInstalled -__all__ = ["RenderCollectionV0", "RecordVideoV0", "HumanRenderingV0"] +__all__ = [ + "RenderCollection", + "RecordVideo", + "HumanRendering", +] -class RenderCollectionV0( +class RenderCollection( gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs ): - """Collect rendered frames of an environment such ``render`` returns a ``list[RenderedFrame]``.""" + """Collect rendered frames of an environment such ``render`` returns a ``list[RenderedFrame]``. + + No vector version of the wrapper exists. + + Example: + Return the list of frames for the number of steps ``render`` wasn't called. + >>> import gymnasium as gym + >>> env = gym.make("LunarLander-v2", render_mode="rgb_array") + >>> env = RenderCollection(env) + >>> _ = env.reset(seed=123) + >>> for _ in range(5): + ... _ = env.step(env.action_space.sample()) + ... + >>> frames = env.render() + >>> len(frames) + 6 + + >>> frames = env.render() + >>> len(frames) + 0 + + Return the list of frames for the number of steps the episode was running. + >>> import gymnasium as gym + >>> env = gym.make("LunarLander-v2", render_mode="rgb_array") + >>> env = RenderCollection(env, pop_frames=False) + >>> _ = env.reset(seed=123) + >>> for _ in range(5): + ... 
_ = env.step(env.action_space.sample()) + ... + >>> frames = env.render() + >>> len(frames) + 6 + + >>> frames = env.render() + >>> len(frames) + 6 + + Collect all frames for all episodes, without clearing them when render is called + >>> import gymnasium as gym + >>> env = gym.make("LunarLander-v2", render_mode="rgb_array") + >>> env = RenderCollection(env, pop_frames=False, reset_clean=False) + >>> _ = env.reset(seed=123) + >>> for _ in range(5): + ... _ = env.step(env.action_space.sample()) + ... + >>> _ = env.reset(seed=123) + >>> for _ in range(5): + ... _ = env.step(env.action_space.sample()) + ... + >>> frames = env.render() + >>> len(frames) + 12 + + >>> frames = env.render() + >>> len(frames) + 12 + + Change logs: + * v0.26.2 - Initially added + """ def __init__( self, @@ -89,20 +152,78 @@ class RenderCollectionV0( return frames -class RecordVideoV0( +class RecordVideo( gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs ): - """This wrapper records videos of rollouts. + """Records videos of environment episodes using the environment's render function. - Usually, you only want to record episodes intermittently, say every hundredth episode. + .. py:currentmodule:: gymnasium.utils.save_video + + Usually, you only want to record episodes intermittently, say every hundredth episode or at every thousandth environment step. To do this, you can specify ``episode_trigger`` or ``step_trigger``. They should be functions returning a boolean that indicates whether a recording should be started at the current episode or step, respectively. - If neither :attr:`episode_trigger` nor ``step_trigger`` is passed, a default ``episode_trigger`` will be employed, - i.e. capped_cubic_video_schedule. This function starts a video at every episode that is a power of 3 until 1000 and - then every 1000 episodes. - By default, the recording will be stopped once reset is called. However, you can also create recordings of fixed - length (possibly spanning several episodes) by passing a strictly positive value for ``video_length``. + + The ``episode_trigger`` should return ``True`` on the episode when recording should start. + The ``step_trigger`` should return ``True`` on the n-th environment step that the recording should be started, where n sums over all previous episodes. + If neither :attr:`episode_trigger` nor ``step_trigger`` is passed, a default ``episode_trigger`` will be employed, i.e. :func:`capped_cubic_video_schedule`. + This function starts a video at every episode that is a power of 3 until 1000 and then every 1000 episodes. + By default, the recording will be stopped once reset is called. + However, you can also create recordings of fixed length (possibly spanning several episodes) + by passing a strictly positive value for ``video_length``. + + No vector version of the wrapper exists. + + Examples - Run the environment for 50 episodes, and save the video every 10 episodes starting from the 0th: + >>> import os + >>> import gymnasium as gym + >>> env = gym.make("LunarLander-v2", render_mode="rgb_array") + >>> trigger = lambda t: t % 10 == 0 + >>> env = RecordVideo(env, video_folder="./save_videos1", episode_trigger=trigger, disable_logger=True) + >>> for i in range(50): + ... termination, truncation = False, False + ... _ = env.reset(seed=123) + ... while not (termination or truncation): + ... obs, rew, termination, truncation, info = env.step(env.action_space.sample()) + ... 
+ >>> env.close() + >>> len(os.listdir("./save_videos1")) + 5 + + Examples - Run the environment for 5 episodes, start a recording every 200th step, making sure each video is 100 frames long: + >>> import os + >>> import gymnasium as gym + >>> env = gym.make("LunarLander-v2", render_mode="rgb_array") + >>> trigger = lambda t: t % 200 == 0 + >>> env = RecordVideo(env, video_folder="./save_videos2", step_trigger=trigger, video_length=100, disable_logger=True) + >>> for i in range(5): + ... termination, truncation = False, False + ... _ = env.reset(seed=123) + ... _ = env.action_space.seed(123) + ... while not (termination or truncation): + ... obs, rew, termination, truncation, info = env.step(env.action_space.sample()) + ... + >>> env.close() + >>> len(os.listdir("./save_videos2")) + 2 + + Examples - Run 3 episodes, record everything, but in chunks of 1000 frames: + >>> import os + >>> import gymnasium as gym + >>> env = gym.make("LunarLander-v2", render_mode="rgb_array") + >>> env = RecordVideo(env, video_folder="./save_videos3", video_length=1000, disable_logger=True) + >>> for i in range(3): + ... termination, truncation = False, False + ... _ = env.reset(seed=123) + ... while not (termination or truncation): + ... obs, rew, termination, truncation, info = env.step(env.action_space.sample()) + ... + >>> env.close() + >>> len(os.listdir("./save_videos3")) + 2 + + Change logs: + * v0.25.0 - Initially added to replace ``wrappers.monitoring.VideoRecorder`` """ def __init__( @@ -127,7 +248,7 @@ class RecordVideoV0( Otherwise, snippets of the specified length are captured name_prefix (str): Will be prepended to the filename of the recordings fps (int): The frame per second in the video. The default value is the one specified in the environment metadata. - If the environment metadata doesn't specify `render_fps`, the value 30 is used. + If the environment metadata doesn't specify ``render_fps``, the value 30 is used. disable_logger (bool): Whether to disable moviepy logger or not """ gym.utils.RecordConstructorArgs.__init__( @@ -148,12 +269,7 @@ class RecordVideoV0( ) if episode_trigger is None and step_trigger is None: - - def capped_cubic_video_schedule(episode_id: int) -> bool: - if episode_id < 1000: - return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id - else: - return episode_id % 1000 == 0 + from gymnasium.utils.save_video import capped_cubic_video_schedule episode_trigger = capped_cubic_video_schedule @@ -244,6 +360,25 @@ class RecordVideoV0( return obs, rew, terminated, truncated, info + def render(self) -> RenderFrame | list[RenderFrame]: + """Compute the render frames as specified by render_mode attribute during initialization of the environment.""" + render_out = super().render() + if self.recording and isinstance(render_out, List): + self.recorded_frames += render_out + + if len(self.render_history) > 0: + tmp_history = self.render_history + self.render_history = [] + return tmp_history + render_out + else: + return render_out + + def close(self): + """Closes the wrapper then the video recorder.""" + super().close() + if self.recording: + self.stop_recording() + def start_recording(self, video_name: str): """Start a new recording. 
If it is already recording, stops the current recording before starting the new one.""" if self.recording: @@ -275,35 +410,16 @@ class RecordVideoV0( self.recording = False self._video_name = None - def render(self) -> RenderFrame | list[RenderFrame]: - """Compute the render frames as specified by render_mode attribute during initialization of the environment.""" - render_out = super().render() - if self.recording and isinstance(render_out, List): - self.recorded_frames += render_out - - if len(self.render_history) > 0: - tmp_history = self.render_history - self.render_history = [] - return tmp_history + render_out - else: - return render_out - - def close(self): - """Closes the wrapper then the video recorder.""" - super().close() - if self.recording: - self.stop_recording() - def __del__(self): """Warn the user in case last video wasn't saved.""" if len(self.recorded_frames) > 0: logger.warn("Unable to save last video! Did you call close()?") -class HumanRenderingV0( +class HumanRendering( gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs ): - """Performs human rendering for an environment that only supports "rgb_array"rendering. + """Allows human like rendering for environments that support "rgb_array" rendering. This wrapper is particularly useful when you have implemented an environment that can produce RGB images but haven't implemented any code to render the images to the screen. @@ -312,11 +428,13 @@ class HumanRenderingV0( The ``render_mode`` of the wrapped environment must be either ``'rgb_array'`` or ``'rgb_array_list'``. + No vector version of the wrapper exists. + Example: >>> import gymnasium as gym - >>> from gymnasium.experimental.wrappers import HumanRenderingV0 + >>> from gymnasium.wrappers import HumanRendering >>> env = gym.make("LunarLander-v2", render_mode="rgb_array") - >>> wrapped = HumanRenderingV0(env) + >>> wrapped = HumanRendering(env) >>> obs, _ = wrapped.reset() # This will start rendering to the screen The wrapper can also be applied directly when the environment is instantiated, simply by passing @@ -330,10 +448,13 @@ class HumanRenderingV0( will always return an empty list: >>> env = gym.make("LunarLander-v2", render_mode="rgb_array_list") - >>> wrapped = HumanRenderingV0(env) + >>> wrapped = HumanRendering(env) >>> obs, _ = wrapped.reset() >>> env.render() # env.render() will always return an empty list! [] + + Change logs: + * v0.25.0 - Initially added """ def __init__(self, env: gym.Env[ObsType, ActType]): diff --git a/gymnasium/wrappers/rescale_action.py b/gymnasium/wrappers/rescale_action.py deleted file mode 100644 index 884d5bef6..000000000 --- a/gymnasium/wrappers/rescale_action.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Wrapper for rescaling actions to within a max and min action.""" -from typing import Union - -import numpy as np - -import gymnasium as gym -from gymnasium.spaces import Box - - -class RescaleAction(gym.ActionWrapper, gym.utils.RecordConstructorArgs): - """Affinely rescales the continuous action space of the environment to the range [min_action, max_action]. - - The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action` - or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space. 
- - Example: - >>> import gymnasium as gym - >>> from gymnasium.wrappers import RescaleAction - >>> import numpy as np - >>> env = gym.make("Hopper-v4") - >>> _ = env.reset(seed=42) - >>> obs, _, _, _, _ = env.step(np.array([1,1,1])) - >>> _ = env.reset(seed=42) - >>> min_action = -0.5 - >>> max_action = np.array([0.0, 0.5, 0.75]) - >>> wrapped_env = RescaleAction(env, min_action=min_action, max_action=max_action) - >>> wrapped_env_obs, _, _, _, _ = wrapped_env.step(max_action) - >>> np.alltrue(obs == wrapped_env_obs) - True - """ - - def __init__( - self, - env: gym.Env, - min_action: Union[float, int, np.ndarray], - max_action: Union[float, int, np.ndarray], - ): - """Initializes the :class:`RescaleAction` wrapper. - - Args: - env (Env): The environment to apply the wrapper - min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar. - max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar. - """ - assert isinstance( - env.action_space, Box - ), f"expected Box action space, got {type(env.action_space)}" - assert np.less_equal(min_action, max_action).all(), (min_action, max_action) - - gym.utils.RecordConstructorArgs.__init__( - self, min_action=min_action, max_action=max_action - ) - gym.ActionWrapper.__init__(self, env) - - self.min_action = ( - np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action - ) - self.max_action = ( - np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + max_action - ) - - self.action_space = Box( - low=min_action, - high=max_action, - shape=env.action_space.shape, - dtype=env.action_space.dtype, - ) - - def action(self, action): - """Rescales the action affinely from [:attr:`min_action`, :attr:`max_action`] to the action space of the base environment, :attr:`env`. - - Args: - action: The action to rescale - - Returns: - The rescaled action - """ - assert np.all(np.greater_equal(action, self.min_action)), ( - action, - self.min_action, - ) - assert np.all(np.less_equal(action, self.max_action)), (action, self.max_action) - low = self.env.action_space.low - high = self.env.action_space.high - action = low + (high - low) * ( - (action - self.min_action) / (self.max_action - self.min_action) - ) - action = np.clip(action, low, high) - return action diff --git a/gymnasium/wrappers/resize_observation.py b/gymnasium/wrappers/resize_observation.py deleted file mode 100644 index 82829adbf..000000000 --- a/gymnasium/wrappers/resize_observation.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Wrapper for resizing observations.""" -from __future__ import annotations - -import numpy as np - -import gymnasium as gym -from gymnasium.error import DependencyNotInstalled -from gymnasium.spaces import Box - - -class ResizeObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): - """Resize the image observation. - - This wrapper works on environments with image observations. More generally, - the input can either be two-dimensional (AxB, e.g. grayscale images) or - three-dimensional (AxBxC, e.g. color images). This resizes the observation - to the shape given by the 2-tuple :attr:`shape`. - The argument :attr:`shape` may also be an integer, in which case, the - observation is scaled to a square of side-length :attr:`shape`. 
- - Example: - >>> import gymnasium as gym - >>> from gymnasium.wrappers import ResizeObservation - >>> env = gym.make("CarRacing-v2") - >>> env.observation_space.shape - (96, 96, 3) - >>> env = ResizeObservation(env, 64) - >>> env.observation_space.shape - (64, 64, 3) - """ - - def __init__(self, env: gym.Env, shape: tuple[int, int] | int) -> None: - """Resizes image observations to shape given by :attr:`shape`. - - Args: - env: The environment to apply the wrapper - shape: The shape of the resized observations - """ - gym.utils.RecordConstructorArgs.__init__(self, shape=shape) - gym.ObservationWrapper.__init__(self, env) - - if isinstance(shape, int): - shape = (shape, shape) - assert len(shape) == 2 and all( - x > 0 for x in shape - ), f"Expected shape to be a 2-tuple of positive integers, got: {shape}" - - self.shape = tuple(shape) - - assert isinstance( - env.observation_space, Box - ), f"Expected the observation space to be Box, actual type: {type(env.observation_space)}" - dims = len(env.observation_space.shape) - assert ( - dims == 2 or dims == 3 - ), f"Expected the observation space to have 2 or 3 dimensions, got: {dims}" - - obs_shape = self.shape + env.observation_space.shape[2:] - self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8) - - def observation(self, observation): - """Updates the observations by resizing the observation to shape given by :attr:`shape`. - - Args: - observation: The observation to reshape - - Returns: - The reshaped observations - - Raises: - DependencyNotInstalled: opencv-python is not installed - """ - try: - import cv2 - except ImportError as e: - raise DependencyNotInstalled( - "opencv (cv2) is not installed, run `pip install gymnasium[other]`" - ) from e - - observation = cv2.resize( - observation, self.shape[::-1], interpolation=cv2.INTER_AREA - ) - return observation.reshape(self.observation_space.shape) diff --git a/gymnasium/experimental/wrappers/stateful_action.py b/gymnasium/wrappers/stateful_action.py similarity index 64% rename from gymnasium/experimental/wrappers/stateful_action.py rename to gymnasium/wrappers/stateful_action.py index c1e055cc5..edac53804 100644 --- a/gymnasium/experimental/wrappers/stateful_action.py +++ b/gymnasium/wrappers/stateful_action.py @@ -8,16 +8,36 @@ from gymnasium.core import ActType, ObsType from gymnasium.error import InvalidProbability -__all__ = ["StickyActionV0"] +__all__ = ["StickyAction"] -class StickyActionV0( +class StickyAction( gym.ActionWrapper[ObsType, ActType, ActType], gym.utils.RecordConstructorArgs ): - """Wrapper which adds a probability of repeating the previous action. + """Adds a probability that the action is repeated for the same ``step`` function. This wrapper follows the implementation proposed by `Machado et al., 2018 `_ in Section 5.2 on page 12. + + No vector version of the wrapper exists. 
+ + Example: + >>> import gymnasium as gym + >>> env = gym.make("CartPole-v1") + >>> env = StickyAction(env, repeat_action_probability=0.9) + >>> env.reset(seed=123) + (array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), {}) + >>> env.step(1) + (array([ 0.01734283, 0.15089367, -0.02859527, -0.33293587], dtype=float32), 1.0, False, False, {}) + >>> env.step(0) + (array([ 0.0203607 , 0.34641072, -0.03525399, -0.6344974 ], dtype=float32), 1.0, False, False, {}) + >>> env.step(1) + (array([ 0.02728892, 0.5420062 , -0.04794393, -0.9380709 ], dtype=float32), 1.0, False, False, {}) + >>> env.step(0) + (array([ 0.03812904, 0.34756234, -0.06670535, -0.6608303 ], dtype=float32), 1.0, False, False, {}) + + Change logs: + * v1.0.0 - Initially added """ def __init__( diff --git a/gymnasium/experimental/wrappers/stateful_observation.py b/gymnasium/wrappers/stateful_observation.py similarity index 73% rename from gymnasium/experimental/wrappers/stateful_observation.py rename to gymnasium/wrappers/stateful_observation.py index d121f540c..5817d1a5a 100644 --- a/gymnasium/experimental/wrappers/stateful_observation.py +++ b/gymnasium/wrappers/stateful_observation.py @@ -1,10 +1,10 @@ """A collection of stateful observation wrappers. -* ``DelayObservationV0`` - A wrapper for delaying the returned observation -* ``TimeAwareObservationV0`` - A wrapper for adding time aware observations to environment observation -* ``FrameStackObservationV0`` - Frame stack the observations -* ``NormalizeObservationV0`` - Normalized the observations to a mean and -* ``MaxAndSkipObservationV0`` - Return only every ``skip``-th frame (frameskipping) and return the max between the two last frames. +* ``DelayObservation`` - A wrapper for delaying the returned observation +* ``TimeAwareObservation`` - A wrapper for adding time aware observations to environment observation +* ``FrameStackObservation`` - Frame stack the observations +* ``NormalizeObservation`` - Normalized the observations to have unit variance with a moving mean +* ``MaxAndSkipObservation`` - Return only every ``skip``-th frame (frameskipping) and return the max between the two last frames. """ from __future__ import annotations @@ -17,38 +17,40 @@ import numpy as np import gymnasium as gym import gymnasium.spaces as spaces from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType -from gymnasium.experimental.vector.utils import ( - batch_space, - concatenate, - create_empty_array, -) -from gymnasium.experimental.wrappers.utils import RunningMeanStd, create_zero_array from gymnasium.spaces import Box, Dict, Tuple +from gymnasium.vector.utils import batch_space, concatenate, create_empty_array +from gymnasium.wrappers.utils import RunningMeanStd, create_zero_array __all__ = [ - "DelayObservationV0", - "TimeAwareObservationV0", - "FrameStackObservationV0", - "NormalizeObservationV0", + "DelayObservation", + "TimeAwareObservation", + "FrameStackObservation", + "NormalizeObservation", + "MaxAndSkipObservation", ] -class DelayObservationV0( +class DelayObservation( gym.ObservationWrapper[ObsType, ActType, ObsType], gym.utils.RecordConstructorArgs ): - """Wrapper which adds a delay to the returned observation. + """Adds a delay to the returned observation from the environment. Before reaching the :attr:`delay` number of timesteps, returned observations is an array of zeros with the same shape as the observation space. + No vector version of the wrapper exists. 
+ + Note: + This does not support random delay values, if users are interested, please raise an issue or pull request to add this feature. + Example: >>> import gymnasium as gym >>> env = gym.make("CartPole-v1") >>> env.reset(seed=123) (array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), {}) - >>> env = DelayObservationV0(env, delay=2) + >>> env = DelayObservation(env, delay=2) >>> env.reset(seed=123) (array([0., 0., 0., 0.], dtype=float32), {}) >>> env.step(env.action_space.sample()) @@ -56,8 +58,8 @@ class DelayObservationV0( >>> env.step(env.action_space.sample()) (array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32), 1.0, False, False, {}) - Note: - This does not support random delay values, if users are interested, please raise an issue or pull request to add this feature. + Change logs: + * v1.0.0 - Initially added """ def __init__(self, env: gym.Env[ObsType, ActType], delay: int): @@ -100,14 +102,14 @@ class DelayObservationV0( return create_zero_array(self.observation_space) -class TimeAwareObservationV0( +class TimeAwareObservation( gym.ObservationWrapper[WrapperObsType, ActType, ObsType], gym.utils.RecordConstructorArgs, ): - """Augment the observation with time information of the episode. + """Augment the observation with the number of time steps taken within an episode. The :attr:`normalize_time` if ``True`` represents time as a normalized value between [0,1] - otherwise if ``False``, the number of timesteps remaining before truncation occurs is an integer. + otherwise if ``False``, the current timestep is an integer. For environments with ``Dict`` observation spaces, the time information is automatically added in the key `"time"` (can be changed through :attr:`dict_time_key`) and for environments with ``Tuple`` @@ -118,33 +120,26 @@ class TimeAwareObservationV0( To flatten the observation, use the :attr:`flatten` parameter which will use the :func:`gymnasium.spaces.utils.flatten` function. + No vector version of the wrapper exists. + Example: >>> import gymnasium as gym - >>> from gymnasium.experimental.wrappers import TimeAwareObservationV0 + >>> from gymnasium.wrappers import TimeAwareObservation >>> env = gym.make("CartPole-v1") - >>> env = TimeAwareObservationV0(env) + >>> env = TimeAwareObservation(env) >>> env.observation_space - Dict('obs': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), 'time': Box(0.0, 1.0, (1,), float32)) + Box([-4.80000019e+00 -3.40282347e+38 -4.18879032e-01 -3.40282347e+38 + 0.00000000e+00], [4.80000019e+00 3.40282347e+38 4.18879032e-01 3.40282347e+38 + 5.00000000e+02], (5,), float64) >>> env.reset(seed=42)[0] - {'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': array([0.], dtype=float32)} + array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 , 0. ]) >>> _ = env.action_space.seed(42) >>> env.step(env.action_space.sample())[0] - {'obs': array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476], dtype=float32), 'time': array([0.002], dtype=float32)} + array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476, 1. 
]) - Unnormalize time observation space example: + Normalize time observation space example: >>> env = gym.make('CartPole-v1') - >>> env = TimeAwareObservationV0(env, normalize_time=False) - >>> env.observation_space - Dict('obs': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), 'time': Box(0, 500, (1,), int32)) - >>> env.reset(seed=42)[0] - {'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': array([500], dtype=int32)} - >>> _ = env.action_space.seed(42)[0] - >>> env.step(env.action_space.sample())[0] - {'obs': array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476], dtype=float32), 'time': array([499], dtype=int32)} - - Flatten observation space example: - >>> env = gym.make("CartPole-v1") - >>> env = TimeAwareObservationV0(env, flatten=True) + >>> env = TimeAwareObservation(env, normalize_time=True) >>> env.observation_space Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38 0.0000000e+00], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38 1.0000000e+00], (5,), float32) @@ -155,17 +150,32 @@ class TimeAwareObservationV0( >>> env.step(env.action_space.sample())[0] array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476, 0.002 ], dtype=float32) + + Flatten observation space example: + >>> env = gym.make("CartPole-v1") + >>> env = TimeAwareObservation(env, flatten=False) + >>> env.observation_space + Dict('obs': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), 'time': Box(0, 500, (1,), int32)) + >>> env.reset(seed=42)[0] + {'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': array([0], dtype=int32)} + >>> _ = env.action_space.seed(42) + >>> env.step(env.action_space.sample())[0] + {'obs': array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476], dtype=float32), 'time': array([1], dtype=int32)} + + Change logs: + * v0.18.0 - Initially added + * v1.0.0 - Remove vector environment support, add ``flatten`` and ``normalize_time`` parameters """ def __init__( self, env: gym.Env[ObsType, ActType], - flatten: bool = False, - normalize_time: bool = True, + flatten: bool = True, + normalize_time: bool = False, *, dict_time_key: str = "time", ): - """Initialize :class:`TimeAwareObservationV0`. + """Initialize :class:`TimeAwareObservation`. Args: env: The environment to apply the wrapper @@ -193,7 +203,7 @@ class TimeAwareObservationV0( "The environment must be wrapped by a TimeLimit wrapper or the spec specify a `max_episode_steps`." 
) - self._timesteps: int = 0 + self.timesteps: int = 0 # Find the normalized time space if self.normalize_time: @@ -202,9 +212,7 @@ class TimeAwareObservationV0( ) time_space = Box(0.0, 1.0) else: - self._time_preprocess_func = lambda time: np.array( - [self.max_timesteps - time], dtype=np.int32 - ) + self._time_preprocess_func = lambda time: np.array([time], dtype=np.int32) time_space = Box(0, self.max_timesteps, dtype=np.int32) # Find the observation space @@ -244,7 +252,7 @@ class TimeAwareObservationV0( """ return self._obs_postprocess_func( self._append_data_func( - observation, self._time_preprocess_func(self._timesteps) + observation, self._time_preprocess_func(self.timesteps) ) ) @@ -259,7 +267,7 @@ class TimeAwareObservationV0( Returns: The environment's step using the action with the next observation containing the timestep info """ - self._timesteps += 1 + self.timesteps += 1 return super().step(action) @@ -275,36 +283,42 @@ class TimeAwareObservationV0( Returns: Resets the environment with the initial timestep info added the observation """ - self._timesteps = 0 + self.timesteps = 0 return super().reset(seed=seed, options=options) -class FrameStackObservationV0( +class FrameStackObservation( gym.Wrapper[WrapperObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs, ): - """Observation wrapper that stacks the observations in a rolling manner. + """Stacks the observations from the last ``N`` time steps in a rolling manner. For example, if the number of stacks is 4, then the returned observation contains the most recent 4 observations. For environment 'Pendulum-v1', the original observation is an array with shape [3], so if we stack 4 observations, the processed observation has shape [4, 3]. + No vector version of the wrapper exists. + Note: - After :meth:`reset` is called, the frame buffer will be filled with the initial observation. I.e. the observation returned by :meth:`reset` will consist of `num_stack` many identical frames. Example: >>> import gymnasium as gym - >>> from gymnasium.experimental.wrappers import FrameStackObservationV0 + >>> from gymnasium.wrappers import FrameStackObservation >>> env = gym.make("CarRacing-v2") - >>> env = FrameStackObservationV0(env, 4) + >>> env = FrameStackObservation(env, 4) >>> env.observation_space Box(0, 255, (4, 96, 96, 3), uint8) >>> obs, _ = env.reset() >>> obs.shape (4, 96, 96, 3) + + Change logs: + * v0.15.0 - Initially add as ``FrameStack`` with support for lz4 + * v1.0.0 - Rename to ``FrameStackObservation`` and remove lz4 and ``LazyFrame`` support """ def __init__( @@ -392,23 +406,49 @@ class FrameStackObservationV0( return updated_obs, info -class NormalizeObservationV0( +class NormalizeObservation( gym.ObservationWrapper[WrapperObsType, ActType, ObsType], gym.utils.RecordConstructorArgs, ): - """This wrapper will normalize observations s.t. each coordinate is centered with unit variance. + """Normalizes observations to be centered at the mean with unit variance. - The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation - statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.observation()` is called. - If `False`, the calculated statistics are used but not updated anymore; this may be used during evaluation. + The property :prop:`_update_running_mean` allows to freeze/continue the running mean calculation of the observation + statistics. 
If ``True`` (default), the ``RunningMeanStd`` will get updated every time ``step`` or ``reset`` is called. + If ``False``, the calculated statistics are used but not updated anymore; this may be used during evaluation. + + A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.NormalizeObservation`. Note: The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was newly instantiated or the policy was changed recently. + + Example: + >>> import numpy as np + >>> import gymnasium as gym + >>> env = gym.make("CartPole-v1") + >>> obs, info = env.reset(seed=123) + >>> term, trunc = False, False + >>> while not (term or trunc): + ... obs, _, term, trunc, _ = env.step(1) + ... + >>> obs + array([ 0.1511158 , 1.7183299 , -0.25533703, -2.8914354 ], dtype=float32) + >>> env = gym.make("CartPole-v1") + >>> env = NormalizeObservation(env) + >>> obs, info = env.reset(seed=123) + >>> term, trunc = False, False + >>> while not (term or trunc): + ... obs, _, term, trunc, _ = env.step(1) + >>> obs + array([ 2.0059888, 1.5676788, -1.9944268, -1.6120394], dtype=float32) + + Change logs: + * v0.21.0 - Initially add + * v1.0.0 - Add `update_running_mean` attribute to allow disabling of updating the running mean / standard """ def __init__(self, env: gym.Env[ObsType, ActType], epsilon: float = 1e-8): - """This wrapper will normalize observations s.t. each coordinate is centered with unit variance. + """This wrapper will normalize observations such that each observation is centered with unit variance. Args: env (Env): The environment to apply the wrapper @@ -417,7 +457,9 @@ class NormalizeObservationV0( gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon) gym.ObservationWrapper.__init__(self, env) - self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) + self.obs_rms = RunningMeanStd( + shape=self.observation_space.shape, dtype=self.observation_space.dtype + ) self.epsilon = epsilon self._update_running_mean = True @@ -434,19 +476,43 @@ class NormalizeObservationV0( def observation(self, observation: ObsType) -> WrapperObsType: """Normalises the observation using the running mean and variance of the observations.""" if self._update_running_mean: - self.obs_rms.update(observation) + self.obs_rms.update(np.array([observation])) return (observation - self.obs_rms.mean) / np.sqrt( self.obs_rms.var + self.epsilon ) -class MaxAndSkipObservationV0( +class MaxAndSkipObservation( gym.Wrapper[WrapperObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs, ): - """This wrapper will return only every ``skip``-th frame (frameskipping) and return the max between the two last observations. + """Skips the N-th frame (observation) and return the max values between the two last observations. - Note: This wrapper is based on the wrapper from stable-baselines3: https://stable-baselines3.readthedocs.io/en/master/_modules/stable_baselines3/common/atari_wrappers.html#MaxAndSkipEnv + No vector version of the wrapper exists. 
+ + Note: + This wrapper is based on the wrapper from [stable-baselines3](https://stable-baselines3.readthedocs.io/en/master/_modules/stable_baselines3/common/atari_wrappers.html#MaxAndSkipEnv) + + Example: + >>> import gymnasium as gym + >>> env = gym.make("CartPole-v1") + >>> obs0, *_ = env.reset(seed=123) + >>> obs1, *_ = env.step(1) + >>> obs2, *_ = env.step(1) + >>> obs3, *_ = env.step(1) + >>> obs4, *_ = env.step(1) + >>> skip_and_max_obs = np.max(np.stack([obs3, obs4], axis=0), axis=0) + >>> env = gym.make("CartPole-v1") + >>> wrapped_env = MaxAndSkipObservation(env) + >>> wrapped_obs0, *_ = wrapped_env.reset(seed=123) + >>> wrapped_obs1, *_ = wrapped_env.step(1) + >>> np.all(obs0 == wrapped_obs0) + True + >>> np.all(wrapped_obs1 == skip_and_max_obs) + True + + Change logs: + * v1.0.0 - Initially add """ def __init__(self, env: gym.Env[ObsType, ActType], skip: int = 4): @@ -492,14 +558,13 @@ class MaxAndSkipObservationV0( info = {} for i in range(self._skip): obs, reward, terminated, truncated, info = self.env.step(action) - done = terminated or truncated if i == self._skip - 2: self._obs_buffer[0] = obs if i == self._skip - 1: self._obs_buffer[1] = obs total_reward += float(reward) - if done: + if terminated or truncated: break - max_frame = self._obs_buffer.max(axis=0) + max_frame = np.max(self._obs_buffer, axis=0) return max_frame, total_reward, terminated, truncated, info diff --git a/gymnasium/experimental/wrappers/stateful_reward.py b/gymnasium/wrappers/stateful_reward.py similarity index 52% rename from gymnasium/experimental/wrappers/stateful_reward.py rename to gymnasium/wrappers/stateful_reward.py index 2ce274af7..053050d2c 100644 --- a/gymnasium/experimental/wrappers/stateful_reward.py +++ b/gymnasium/wrappers/stateful_reward.py @@ -1,6 +1,6 @@ """A collection of wrappers for modifying the reward with an internal state. -* ``NormalizeRewardV1`` - Normalizes the rewards to a mean and standard deviation +* ``NormalizeReward`` - Normalizes the rewards to a mean and standard deviation """ from __future__ import annotations @@ -10,16 +10,16 @@ import numpy as np import gymnasium as gym from gymnasium.core import ActType, ObsType -from gymnasium.experimental.wrappers.utils import RunningMeanStd +from gymnasium.wrappers.utils import RunningMeanStd -__all__ = ["NormalizeRewardV1"] +__all__ = ["NormalizeReward"] -class NormalizeRewardV1( +class NormalizeReward( gym.Wrapper[ObsType, ActType, ObsType, ActType], gym.utils.RecordConstructorArgs ): - r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. + r"""Normalizes immediate rewards such that their exponential moving average has a fixed variance. The exponential moving average will have variance :math:`(1 - \gamma)^2`. @@ -27,13 +27,53 @@ class NormalizeRewardV1( statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.normalize()` is called. If False, the calculated statistics are used but not updated anymore; this may be used during evaluation. + A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.NormalizeReward`. + Note: - In v0.27, NormalizeReward was updated as the forward discounted reward estimate was incorrect computed in Gym v0.25+. + In v0.27, NormalizeReward was updated as the forward discounted reward estimate was incorrectly computed in Gym v0.25+. For more detail, read [#3154](https://github.com/openai/gym/pull/3152). 
Note: The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly instantiated or the policy was changed recently. + + Example without the normalize reward wrapper: + >>> import numpy as np + >>> import gymnasium as gym + >>> env = gym.make("MountainCarContinuous-v0") + >>> _ = env.reset(seed=123) + >>> _ = env.action_space.seed(123) + >>> episode_rewards = [] + >>> terminated, truncated = False, False + >>> while not (terminated or truncated): + ... observation, reward, terminated, truncated, info = env.step(env.action_space.sample()) + ... episode_rewards.append(reward) + ... + >>> env.close() + >>> np.var(episode_rewards) + 0.0008876301247721108 + + Example with the normalize reward wrapper: + >>> import numpy as np + >>> import gymnasium as gym + >>> env = gym.make("MountainCarContinuous-v0") + >>> env = NormalizeReward(env, gamma=0.99, epsilon=1e-8) + >>> _ = env.reset(seed=123) + >>> _ = env.action_space.seed(123) + >>> episode_rewards = [] + >>> terminated, truncated = False, False + >>> while not (terminated or truncated): + ... observation, reward, terminated, truncated, info = env.step(env.action_space.sample()) + ... episode_rewards.append(reward) + ... + >>> env.close() + >>> # will approach 0.99 with more episodes + >>> np.var(episode_rewards) + 0.010162116476634746 + + Change logs: + * v0.21.0 - Initially added + * v1.0.0 - Add `update_running_mean` attribute to allow disabling of updating the running mean / standard """ def __init__( @@ -52,7 +92,7 @@ class NormalizeRewardV1( gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon) gym.Wrapper.__init__(self, env) - self.rewards_running_means = RunningMeanStd(shape=()) + self.return_rms = RunningMeanStd(shape=()) self.discounted_reward: np.array = np.array([0.0]) self.gamma = gamma self.epsilon = epsilon @@ -73,13 +113,14 @@ class NormalizeRewardV1( ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """Steps through the environment, normalizing the reward returned.""" obs, reward, terminated, truncated, info = super().step(action) + + # Using the `discounted_reward` rather than `reward` makes no sense but for backward compatibility, it is being kept self.discounted_reward = self.discounted_reward * self.gamma * ( 1 - terminated ) + float(reward) - return obs, self.normalize(float(reward)), terminated, truncated, info - - def normalize(self, reward: SupportsFloat): - """Normalizes the rewards with the running mean rewards and their variance.""" if self._update_running_mean: - self.rewards_running_means.update(self.discounted_reward) - return reward / np.sqrt(self.rewards_running_means.var + self.epsilon) + self.return_rms.update(self.discounted_reward) + + # We don't (reward - self.return_rms.mean) see https://github.com/openai/baselines/issues/538 + normalized_reward = reward / np.sqrt(self.return_rms.var + self.epsilon) + return obs, normalized_reward, terminated, truncated, info diff --git a/gymnasium/wrappers/step_api_compatibility.py b/gymnasium/wrappers/step_api_compatibility.py deleted file mode 100644 index f6f7ddbc2..000000000 --- a/gymnasium/wrappers/step_api_compatibility.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Implementation of StepAPICompatibility wrapper class for transforming envs between new and old step API.""" -import gymnasium as gym -from gymnasium.logger import deprecation -from gymnasium.utils.step_api_compatibility import step_api_compatibility - - -class StepAPICompatibility(gym.Wrapper, gym.utils.RecordConstructorArgs): - 
r"""A wrapper which can transform an environment from new step API to old and vice-versa. - - Old step API refers to step() method returning (observation, reward, done, info) - New step API refers to step() method returning (observation, reward, terminated, truncated, info) - (Refer to docs for details on the API change) - - Example: - >>> import gymnasium as gym - >>> from gymnasium.wrappers import StepAPICompatibility - >>> env = gym.make("CartPole-v1") - >>> env # wrapper not applied by default, set to new API - >>>> - >>> env = StepAPICompatibility(gym.make("CartPole-v1")) - >>> env - >>>>> - """ - - def __init__(self, env: gym.Env, output_truncation_bool: bool = True): - """A wrapper which can transform an environment from new step API to old and vice-versa. - - Args: - env (gym.Env): the env to wrap. Can be in old or new API - output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) - """ - gym.utils.RecordConstructorArgs.__init__( - self, output_truncation_bool=output_truncation_bool - ) - gym.Wrapper.__init__(self, env) - - self.is_vector_env = isinstance(env.unwrapped, gym.vector.VectorEnv) - self.output_truncation_bool = output_truncation_bool - if not self.output_truncation_bool: - deprecation( - "Initializing environment in (old) done step API which returns one bool instead of two." - ) - - def step(self, action): - """Steps through the environment, returning 5 or 4 items depending on `output_truncation_bool`. - - Args: - action: action to step through the environment with - - Returns: - (observation, reward, terminated, truncated, info) or (observation, reward, done, info) - """ - step_returns = self.env.step(action) - return step_api_compatibility( - step_returns, self.output_truncation_bool, self.is_vector_env - ) diff --git a/gymnasium/wrappers/time_aware_observation.py b/gymnasium/wrappers/time_aware_observation.py deleted file mode 100644 index 491425b5e..000000000 --- a/gymnasium/wrappers/time_aware_observation.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Wrapper for adding time aware observations to environment observation.""" -import numpy as np - -import gymnasium as gym -from gymnasium.spaces import Box - - -class TimeAwareObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): - """Augment the observation with the current time step in the episode. - - The observation space of the wrapped environment is assumed to be a flat :class:`Box`. - In particular, pixel observations are not supported. This wrapper will append the current timestep within the current episode to the observation. - - Example: - >>> import gymnasium as gym - >>> from gymnasium.wrappers import TimeAwareObservation - >>> env = gym.make("CartPole-v1") - >>> env = TimeAwareObservation(env) - >>> env.reset(seed=42) - (array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 , 0. ]), {}) - >>> _ = env.action_space.seed(42) - >>> env.step(env.action_space.sample())[0] - array([ 0.02727336, -0.20172954, 0.03625453, 0.32351476, 1. ]) - """ - - def __init__(self, env: gym.Env): - """Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` observation space. 
- - Args: - env: The environment to apply the wrapper - """ - gym.utils.RecordConstructorArgs.__init__(self) - gym.ObservationWrapper.__init__(self, env) - - assert isinstance(env.observation_space, Box) - assert env.observation_space.dtype == np.float32 - low = np.append(self.observation_space.low, 0.0) - high = np.append(self.observation_space.high, np.inf) - self.observation_space = Box(low, high, dtype=np.float32) - - try: - self.is_vector_env = self.get_wrapper_attr("is_vector_env") - except AttributeError: - self.is_vector_env = False - - def observation(self, observation): - """Adds to the observation with the current time step. - - Args: - observation: The observation to add the time step to - - Returns: - The observation with the time step appended to - """ - return np.append(observation, self.t) - - def step(self, action): - """Steps through the environment, incrementing the time step. - - Args: - action: The action to take - - Returns: - The environment's step using the action. - """ - self.t += 1 - return super().step(action) - - def reset(self, **kwargs): - """Reset the environment setting the time to zero. - - Args: - **kwargs: Kwargs to apply to env.reset() - - Returns: - The reset environment - """ - self.t = 0 - return super().reset(**kwargs) diff --git a/gymnasium/wrappers/time_limit.py b/gymnasium/wrappers/time_limit.py deleted file mode 100644 index 48d564ad9..000000000 --- a/gymnasium/wrappers/time_limit.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Wrapper for limiting the time steps of an environment.""" -from __future__ import annotations - -from copy import deepcopy -from typing import TYPE_CHECKING - -import gymnasium as gym - - -if TYPE_CHECKING: - from gymnasium.envs.registration import EnvSpec - - -class TimeLimit(gym.Wrapper, gym.utils.RecordConstructorArgs): - """This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded. - - If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued. - Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP. - - Example: - >>> import gymnasium as gym - >>> from gymnasium.wrappers import TimeLimit - >>> env = gym.make("CartPole-v1") - >>> env = TimeLimit(env, max_episode_steps=1000) - """ - - def __init__( - self, - env: gym.Env, - max_episode_steps: int, - ): - """Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur. - - Args: - env: The environment to apply the wrapper - max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used) - """ - gym.utils.RecordConstructorArgs.__init__( - self, max_episode_steps=max_episode_steps - ) - gym.Wrapper.__init__(self, env) - - self._max_episode_steps = max_episode_steps - self._elapsed_steps = None - - def step(self, action): - """Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate. 
- - Args: - action: The environment step action - - Returns: - The environment step ``(observation, reward, terminated, truncated, info)`` with `truncated=True` - if the number of steps elapsed >= max episode steps - - """ - observation, reward, terminated, truncated, info = self.env.step(action) - self._elapsed_steps += 1 - - if self._elapsed_steps >= self._max_episode_steps: - truncated = True - - return observation, reward, terminated, truncated, info - - def reset(self, **kwargs): - """Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero. - - Args: - **kwargs: The kwargs to reset the environment with - - Returns: - The reset environment - """ - self._elapsed_steps = 0 - return self.env.reset(**kwargs) - - @property - def spec(self) -> EnvSpec | None: - """Modifies the environment spec to include the `max_episode_steps=self._max_episode_steps`.""" - if self._cached_spec is not None: - return self._cached_spec - - env_spec = self.env.spec - if env_spec is not None: - env_spec = deepcopy(env_spec) - env_spec.max_episode_steps = self._max_episode_steps - - self._cached_spec = env_spec - return env_spec diff --git a/gymnasium/experimental/wrappers/lambda_action.py b/gymnasium/wrappers/transform_action.py similarity index 69% rename from gymnasium/experimental/wrappers/lambda_action.py rename to gymnasium/wrappers/transform_action.py index 510fc749d..a493d47b4 100644 --- a/gymnasium/experimental/wrappers/lambda_action.py +++ b/gymnasium/wrappers/transform_action.py @@ -1,8 +1,8 @@ """A collection of wrappers that all use the LambdaAction class. -* ``LambdaActionV0`` - Transforms the actions based on a function -* ``ClipActionV0`` - Clips the action within a bounds -* ``RescaleActionV0`` - Rescales the action within a minimum and maximum actions +* ``TransformAction`` - Transforms the actions based on a function +* ``ClipAction`` - Clips the action within a bounds +* ``RescaleAction`` - Rescales the action within a minimum and maximum actions """ from __future__ import annotations @@ -15,13 +15,34 @@ from gymnasium.core import ActType, ObsType, WrapperActType from gymnasium.spaces import Box, Space -__all__ = ["LambdaActionV0", "ClipActionV0", "RescaleActionV0"] +__all__ = ["TransformAction", "ClipAction", "RescaleAction"] -class LambdaActionV0( +class TransformAction( gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs ): - """A wrapper that provides a function to modify the action passed to :meth:`step`.""" + """Applies a function to the ``action`` before passing the modified value to the environment ``step`` function. + + A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.TransformAction`. + + Example: + >>> import numpy as np + >>> import gymnasium as gym + >>> env = gym.make("MountainCarContinuous-v0") + >>> _ = env.reset(seed=123) + >>> obs, *_= env.step(np.array([0.0, 1.0])) + >>> obs + array([-4.6397772e-01, -4.4808415e-04], dtype=float32) + >>> env = gym.make("MountainCarContinuous-v0") + >>> env = TransformAction(env, lambda a: 0.5 * a + 0.1, env.action_space) + >>> _ = env.reset(seed=123) + >>> obs, *_= env.step(np.array([0.0, 1.0])) + >>> obs + array([-4.6382770e-01, -2.9808417e-04], dtype=float32) + + Change logs: + * v1.0.0 - Initially added + """ def __init__( self, @@ -29,7 +50,7 @@ class LambdaActionV0( func: Callable[[WrapperActType], ActType], action_space: Space[WrapperActType] | None, ): - """Initialize LambdaAction. + """Initialize TransformAction. 
Args: env: The environment to wrap @@ -51,22 +72,28 @@ class LambdaActionV0( return self.func(action) -class ClipActionV0( - LambdaActionV0[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs +class ClipAction( + TransformAction[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs ): - """Clip the continuous action within the valid :class:`Box` observation space bound. + """Clips the ``action`` pass to ``step`` to be within the environment's `action_space`. + + A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.ClipAction`. Example: >>> import gymnasium as gym - >>> from gymnasium.experimental.wrappers import ClipActionV0 + >>> from gymnasium.wrappers import ClipAction >>> import numpy as np >>> env = gym.make("Hopper-v4", disable_env_checker=True) - >>> env = ClipActionV0(env) + >>> env = ClipAction(env) >>> env.action_space Box(-inf, inf, (3,), float32) >>> _ = env.reset(seed=42) >>> _ = env.step(np.array([5.0, -2.0, 0.0], dtype=np.float32)) ... # Executes the action np.array([1.0, -1.0, 0]) in the base environment + + Change logs: + * v0.12.6 - Initially added + * v1.0.0 - Action space is updated to infinite bounds as is technically correct """ def __init__(self, env: gym.Env[ObsType, ActType]): @@ -78,7 +105,7 @@ class ClipActionV0( assert isinstance(env.action_space, Box) gym.utils.RecordConstructorArgs.__init__(self) - LambdaActionV0.__init__( + TransformAction.__init__( self, env=env, func=lambda action: np.clip( @@ -93,17 +120,19 @@ class ClipActionV0( ) -class RescaleActionV0( - LambdaActionV0[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs +class RescaleAction( + TransformAction[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs ): - """Affinely rescales the continuous action space of the environment to the range [min_action, max_action]. + """Affinely (linearly) rescales a ``Box`` action space of the environment to within the range of ``[min_action, max_action]``. The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action` or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space. + A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.RescaleAction`. 
+ Example: >>> import gymnasium as gym - >>> from gymnasium.experimental.wrappers import RescaleActionV0 + >>> from gymnasium.wrappers import RescaleAction >>> import numpy as np >>> env = gym.make("Hopper-v4", disable_env_checker=True) >>> _ = env.reset(seed=42) @@ -111,10 +140,13 @@ class RescaleActionV0( >>> _ = env.reset(seed=42) >>> min_action = -0.5 >>> max_action = np.array([0.0, 0.5, 0.75], dtype=np.float32) - >>> wrapped_env = RescaleActionV0(env, min_action=min_action, max_action=max_action) + >>> wrapped_env = RescaleAction(env, min_action=min_action, max_action=max_action) >>> wrapped_env_obs, _, _, _, _ = wrapped_env.step(max_action) - >>> np.alltrue(obs == wrapped_env_obs) + >>> np.all(obs == wrapped_env_obs) True + + Change logs: + * v0.15.4 - Initially added """ def __init__( @@ -165,7 +197,7 @@ class RescaleActionV0( ) intercept = gradient * -min_action + env.action_space.low - LambdaActionV0.__init__( + TransformAction.__init__( self, env=env, func=lambda action: gradient * action + intercept, diff --git a/gymnasium/wrappers/transform_observation.py b/gymnasium/wrappers/transform_observation.py index d70afe400..9d8c3a51d 100644 --- a/gymnasium/wrappers/transform_observation.py +++ b/gymnasium/wrappers/transform_observation.py @@ -1,15 +1,50 @@ -"""Wrapper for transforming observations.""" -from typing import Any, Callable +"""A collection of observation wrappers using a lambda function. + +* ``TransformObservation`` - Transforms the observation with a function +* ``FilterObservation`` - Filters a ``Tuple`` or ``Dict`` to only include certain keys +* ``FlattenObservation`` - Flattens the observations +* ``GrayscaleObservation`` - Converts a RGB observation to a grayscale observation +* ``ResizeObservation`` - Resizes an array-based observation (normally a RGB observation) +* ``ReshapeObservation`` - Reshapes an array-based observation +* ``RescaleObservation`` - Rescales an observation to between a minimum and maximum value +* ``DtypeObservation`` - Convert an observation to a dtype +* ``RenderObservation`` - Allows the observation to the rendered frame +""" +from __future__ import annotations + +from typing import Any, Callable, Final, Sequence + +import numpy as np import gymnasium as gym +from gymnasium import spaces +from gymnasium.core import ActType, ObsType, WrapperObsType +from gymnasium.error import DependencyNotInstalled -class TransformObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): - """Transform the observation via an arbitrary function :attr:`f`. +__all__ = [ + "TransformObservation", + "FilterObservation", + "FlattenObservation", + "GrayscaleObservation", + "ResizeObservation", + "ReshapeObservation", + "RescaleObservation", + "DtypeObservation", + "RenderObservation", +] - The function :attr:`f` should be defined on the observation space of the base environment, ``env``, and should, ideally, return values in the same space. - If the transformation you wish to apply to observations returns values in a *different* space, you should subclass :class:`ObservationWrapper`, implement the transformation, and set the new observation space accordingly. If you were to use this wrapper instead, the observation space would be set incorrectly. +class TransformObservation( + gym.ObservationWrapper[WrapperObsType, ActType, ObsType], + gym.utils.RecordConstructorArgs, +): + """Applies a function to the ``observation`` received from the environment's :meth:`Env.reset` and :meth:`Env.step` that is passed back to the user. 
+ + The function :attr:`func` will be applied to all observations. + If the observations from :attr:`func` are outside the bounds of the ``env``'s observation space, provide an updated :attr:`observation_space`. + + A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.TransformObservation`. Example: >>> import gymnasium as gym @@ -17,31 +52,664 @@ class TransformObservation(gym.ObservationWrapper, gym.utils.RecordConstructorAr >>> import numpy as np >>> np.random.seed(0) >>> env = gym.make("CartPole-v1") - >>> env = TransformObservation(env, lambda obs: obs + 0.1 * np.random.randn(*obs.shape)) >>> env.reset(seed=42) - (array([0.20380084, 0.03390356, 0.13373359, 0.24382612]), {}) + (array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), {}) + >>> env = gym.make("CartPole-v1") + >>> env = TransformObservation(env, lambda obs: obs + 0.1 * np.random.random(obs.shape), env.observation_space) + >>> env.reset(seed=42) + (array([0.08227695, 0.06540678, 0.09613613, 0.07422512]), {}) + + Change logs: + * v0.15.4 - Initially added + * v1.0.0 - Add requirement of ``observation_space`` """ - def __init__(self, env: gym.Env, f: Callable[[Any], Any]): - """Initialize the :class:`TransformObservation` wrapper with an environment and a transform function :attr:`f`. + def __init__( + self, + env: gym.Env[ObsType, ActType], + func: Callable[[ObsType], Any], + observation_space: gym.Space[WrapperObsType] | None, + ): + """Constructor for the transform observation wrapper. Args: - env: The environment to apply the wrapper - f: A function that transforms the observation + env: The environment to wrap + func: A function that will transform an observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an `observation_space`. + observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``. """ - gym.utils.RecordConstructorArgs.__init__(self, f=f) + gym.utils.RecordConstructorArgs.__init__( + self, func=func, observation_space=observation_space + ) gym.ObservationWrapper.__init__(self, env) - assert callable(f) - self.f = f + if observation_space is not None: + self.observation_space = observation_space - def observation(self, observation): - """Transforms the observations with callable :attr:`f`. + self.func = func + + def observation(self, observation: ObsType) -> Any: + """Apply function to the observation.""" + return self.func(observation) + + +class FilterObservation( + TransformObservation[WrapperObsType, ActType, ObsType], + gym.utils.RecordConstructorArgs, +): + """Filters a Dict or Tuple observation spaces by a set of keys or indexes. + + A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.FilterObservation`. 
+ + Example: + >>> import gymnasium as gym + >>> from gymnasium.wrappers import FilterObservation + >>> env = gym.make("CartPole-v1") + >>> env = gym.wrappers.TimeAwareObservation(env, flatten=False) + >>> env.observation_space + Dict('obs': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32), 'time': Box(0, 500, (1,), int32)) + >>> env.reset(seed=42) + ({'obs': array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32), 'time': array([0], dtype=int32)}, {}) + >>> env = FilterObservation(env, filter_keys=['time']) + >>> env.reset(seed=42) + ({'time': array([0], dtype=int32)}, {}) + >>> env.step(0) + ({'time': array([1], dtype=int32)}, 1.0, False, False, {}) + + Change logs: + * v0.12.3 - Initially added, originally called `FilterObservationWrapper` + * v1.0.0 - Rename to `FilterObservation` and add support for tuple observation spaces with integer ``filter_keys`` + """ + + def __init__( + self, env: gym.Env[ObsType, ActType], filter_keys: Sequence[str | int] + ): + """Constructor for the filter observation wrapper. Args: - observation: The observation to transform - - Returns: - The transformed observation + env: The environment to wrap + filter_keys: The set of subspaces to be *included*, use a list of strings for ``Dict`` and integers for ``Tuple`` spaces """ - return self.f(observation) + if not isinstance(filter_keys, Sequence): + raise TypeError( + f"Expects `filter_keys` to be a Sequence, actual type: {type(filter_keys)}" + ) + gym.utils.RecordConstructorArgs.__init__(self, filter_keys=filter_keys) + + # Filters for dictionary space + if isinstance(env.observation_space, spaces.Dict): + assert all(isinstance(key, str) for key in filter_keys) + + if any( + key not in env.observation_space.spaces.keys() for key in filter_keys + ): + missing_keys = [ + key + for key in filter_keys + if key not in env.observation_space.spaces.keys() + ] + raise ValueError( + "All the `filter_keys` must be included in the observation space.\n" + f"Filter keys: {filter_keys}\n" + f"Observation keys: {list(env.observation_space.spaces.keys())}\n" + f"Missing keys: {missing_keys}" + ) + + new_observation_space = spaces.Dict( + {key: env.observation_space[key] for key in filter_keys} + ) + if len(new_observation_space) == 0: + raise ValueError( + "The observation space is empty due to filtering all of the keys." + ) + + TransformObservation.__init__( + self, + env=env, + func=lambda obs: {key: obs[key] for key in filter_keys}, + observation_space=new_observation_space, + ) + # Filter for tuple observation + elif isinstance(env.observation_space, spaces.Tuple): + assert all(isinstance(key, int) for key in filter_keys) + assert len(set(filter_keys)) == len( + filter_keys + ), f"Duplicate keys exist, filter_keys: {filter_keys}" + + if any( + 0 < key and key >= len(env.observation_space) for key in filter_keys + ): + missing_index = [ + key + for key in filter_keys + if 0 < key and key >= len(env.observation_space) + ] + raise ValueError( + "All the `filter_keys` must be included in the length of the observation space.\n" + f"Filter keys: {filter_keys}, length of observation: {len(env.observation_space)}, " + f"missing indexes: {missing_index}" + ) + + new_observation_spaces = spaces.Tuple( + env.observation_space[key] for key in filter_keys + ) + if len(new_observation_spaces) == 0: + raise ValueError( + "The observation space is empty due to filtering all keys." 
+ ) + + TransformObservation.__init__( + self, + env=env, + func=lambda obs: tuple(obs[key] for key in filter_keys), + observation_space=new_observation_spaces, + ) + else: + raise ValueError( + f"FilterObservation wrapper is only usable with `Dict` and `Tuple` observations, actual type: {type(env.observation_space)}" + ) + + self.filter_keys: Final[Sequence[str | int]] = filter_keys + + +class FlattenObservation( + TransformObservation[WrapperObsType, ActType, ObsType], + gym.utils.RecordConstructorArgs, +): + """Flattens the environment's observation space and each observation from ``reset`` and ``step`` functions. + + A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.FlattenObservation`. + + Example: + >>> import gymnasium as gym + >>> from gymnasium.wrappers import FlattenObservation + >>> env = gym.make("CarRacing-v2") + >>> env.observation_space.shape + (96, 96, 3) + >>> env = FlattenObservation(env) + >>> env.observation_space.shape + (27648,) + >>> obs, _ = env.reset() + >>> obs.shape + (27648,) + + Change logs: + * v0.15.0 - Initially added + """ + + def __init__(self, env: gym.Env[ObsType, ActType]): + """Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``. + + Args: + env: The environment to wrap + """ + gym.utils.RecordConstructorArgs.__init__(self) + TransformObservation.__init__( + self, + env=env, + func=lambda obs: spaces.utils.flatten(env.observation_space, obs), + observation_space=spaces.utils.flatten_space(env.observation_space), + ) + + +class GrayscaleObservation( + TransformObservation[WrapperObsType, ActType, ObsType], + gym.utils.RecordConstructorArgs, +): + """Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale. + + The :attr:`keep_dim` will keep the channel dimension. + + A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.GrayscaleObservation`. + + Example: + >>> import gymnasium as gym + >>> from gymnasium.wrappers import GrayscaleObservation + >>> env = gym.make("CarRacing-v2") + >>> env.observation_space.shape + (96, 96, 3) + >>> grayscale_env = GrayscaleObservation(env) + >>> grayscale_env.observation_space.shape + (96, 96) + >>> grayscale_env = GrayscaleObservation(env, keep_dim=True) + >>> grayscale_env.observation_space.shape + (96, 96, 1) + + Change logs: + * v0.15.0 - Initially added, originally called ``GrayScaleObservation`` + * v1.0.0 - Renamed to ``GrayscaleObservation`` + """ + + def __init__(self, env: gym.Env[ObsType, ActType], keep_dim: bool = False): + """Constructor for an RGB image based environments to make the image grayscale. 
+ + Args: + env: The environment to wrap + keep_dim: If to keep the channel in the observation, if ``True``, ``obs.shape == 3`` else ``obs.shape == 2`` + """ + assert isinstance(env.observation_space, spaces.Box) + assert ( + len(env.observation_space.shape) == 3 + and env.observation_space.shape[-1] == 3 + ) + assert ( + np.all(env.observation_space.low == 0) + and np.all(env.observation_space.high == 255) + and env.observation_space.dtype == np.uint8 + ) + gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim) + + self.keep_dim: Final[bool] = keep_dim + if keep_dim: + new_observation_space = spaces.Box( + low=0, + high=255, + shape=env.observation_space.shape[:2] + (1,), + dtype=np.uint8, + ) + TransformObservation.__init__( + self, + env=env, + func=lambda obs: np.expand_dims( + np.sum( + np.multiply(obs, np.array([0.2125, 0.7154, 0.0721])), axis=-1 + ).astype(np.uint8), + axis=-1, + ), + observation_space=new_observation_space, + ) + else: + new_observation_space = spaces.Box( + low=0, high=255, shape=env.observation_space.shape[:2], dtype=np.uint8 + ) + TransformObservation.__init__( + self, + env=env, + func=lambda obs: np.sum( + np.multiply(obs, np.array([0.2125, 0.7154, 0.0721])), axis=-1 + ).astype(np.uint8), + observation_space=new_observation_space, + ) + + +class ResizeObservation( + TransformObservation[WrapperObsType, ActType, ObsType], + gym.utils.RecordConstructorArgs, +): + """Resizes image observations using OpenCV to a specified shape. + + A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.ResizeObservation`. + + Example: + >>> import gymnasium as gym + >>> from gymnasium.wrappers import ResizeObservation + >>> env = gym.make("CarRacing-v2") + >>> env.observation_space.shape + (96, 96, 3) + >>> resized_env = ResizeObservation(env, (32, 32)) + >>> resized_env.observation_space.shape + (32, 32, 3) + + Change logs: + * v0.12.6 - Initially added + * v1.0.0 - Requires ``shape`` with a tuple of two integers + """ + + def __init__(self, env: gym.Env[ObsType, ActType], shape: tuple[int, int]): + """Constructor that requires an image environment observation space with a shape. 
+
+        Args:
+            env: The environment to wrap
+            shape: The resized observation shape
+        """
+        assert isinstance(env.observation_space, spaces.Box)
+        assert len(env.observation_space.shape) in {2, 3}
+        assert np.all(env.observation_space.low == 0) and np.all(
+            env.observation_space.high == 255
+        )
+        assert env.observation_space.dtype == np.uint8
+
+        assert isinstance(shape, tuple)
+        assert len(shape) == 2
+        assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
+        assert all(x > 0 for x in shape)
+
+        try:
+            import cv2
+        except ImportError as e:
+            raise DependencyNotInstalled(
+                "opencv (cv2) is not installed, run `pip install gymnasium[other]`"
+            ) from e
+
+        self.shape: Final[tuple[int, int]] = tuple(shape)
+        # cv2.resize expects its target size as (width, height), the reverse of NumPy's (height, width) shape ordering
+        self.cv2_shape: Final[tuple[int, int]] = (shape[1], shape[0])
+
+        new_observation_space = spaces.Box(
+            low=0,
+            high=255,
+            shape=self.shape + env.observation_space.shape[2:],
+            dtype=np.uint8,
+        )
+
+        gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
+        TransformObservation.__init__(
+            self,
+            env=env,
+            func=lambda obs: cv2.resize(
+                obs, self.cv2_shape, interpolation=cv2.INTER_AREA
+            ),
+            observation_space=new_observation_space,
+        )
+
+
+class ReshapeObservation(
+    TransformObservation[WrapperObsType, ActType, ObsType],
+    gym.utils.RecordConstructorArgs,
+):
+    """Reshapes array-based observations to a specified shape.
+
+    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.ReshapeObservation`.
+
+    Example:
+        >>> import gymnasium as gym
+        >>> from gymnasium.wrappers import ReshapeObservation
+        >>> env = gym.make("CarRacing-v2")
+        >>> env.observation_space.shape
+        (96, 96, 3)
+        >>> reshape_env = ReshapeObservation(env, (24, 4, 96, 1, 3))
+        >>> reshape_env.observation_space.shape
+        (24, 4, 96, 1, 3)
+
+    Change logs:
+     * v1.0.0 - Initially added
+    """
+
+    def __init__(self, env: gym.Env[ObsType, ActType], shape: int | tuple[int, ...]):
+        """Constructor for env with ``Box`` observation space that has a shape product equal to the new shape product.
+
+        Args:
+            env: The environment to wrap
+            shape: The new shape of the observation
+        """
+        assert isinstance(env.observation_space, spaces.Box)
+        assert np.prod(shape) == np.prod(env.observation_space.shape)
+
+        assert isinstance(shape, tuple)
+        assert all(np.issubdtype(type(elem), np.integer) for elem in shape)
+        assert all(x > 0 or x == -1 for x in shape)
+
+        new_observation_space = spaces.Box(
+            low=np.reshape(np.ravel(env.observation_space.low), shape),
+            high=np.reshape(np.ravel(env.observation_space.high), shape),
+            shape=shape,
+            dtype=env.observation_space.dtype,
+        )
+        self.shape = shape
+
+        gym.utils.RecordConstructorArgs.__init__(self, shape=shape)
+        TransformObservation.__init__(
+            self,
+            env=env,
+            func=lambda obs: np.reshape(obs, shape),
+            observation_space=new_observation_space,
+        )
+
+
+class RescaleObservation(
+    TransformObservation[WrapperObsType, ActType, ObsType],
+    gym.utils.RecordConstructorArgs,
+):
+    """Affinely (linearly) rescales a ``Box`` observation space of the environment to within the range of ``[min_obs, max_obs]``.
+
+    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.RescaleObservation`.
+
+    Example:
+        >>> import numpy as np
+        >>> import gymnasium as gym
+        >>> from gymnasium.wrappers import RescaleObservation
+        >>> env = gym.make("Pendulum-v1")
+        >>> env.observation_space
+        Box([-1. -1. -8.], [1. 1. 8.], (3,), float32)
+        >>> env = RescaleObservation(env, np.array([-2, -1, -10], dtype=np.float32), np.array([1, 0, 1], dtype=np.float32))
+        >>> env.observation_space
+        Box([ -2.  -1. -10.], [1. 0. 1.], (3,), float32)
+
+    Change logs:
+     * v1.0.0 - Initially added
+    """
+
+    def __init__(
+        self,
+        env: gym.Env[ObsType, ActType],
+        min_obs: np.floating | np.integer | np.ndarray,
+        max_obs: np.floating | np.integer | np.ndarray,
+    ):
+        """Constructor that requires the env observation spaces to be a :class:`Box`.
+
+        Args:
+            env: The environment to wrap
+            min_obs: The new minimum observation bound
+            max_obs: The new maximum observation bound
+        """
+        assert isinstance(env.observation_space, spaces.Box)
+        # the original bounds must be finite for the affine rescaling to be well-defined
+        assert not np.any(np.isinf(env.observation_space.low)) and not np.any(
+            np.isinf(env.observation_space.high)
+        )
+
+        if not isinstance(min_obs, np.ndarray):
+            assert np.issubdtype(type(min_obs), np.integer) or np.issubdtype(
+                type(min_obs), np.floating
+            )
+            min_obs = np.full(env.observation_space.shape, min_obs)
+        assert (
+            min_obs.shape == env.observation_space.shape
+        ), f"{min_obs.shape}, {env.observation_space.shape}, {min_obs}, {env.observation_space.low}"
+        assert not np.any(np.isinf(min_obs))
+
+        if not isinstance(max_obs, np.ndarray):
+            assert np.issubdtype(type(max_obs), np.integer) or np.issubdtype(
+                type(max_obs), np.floating
+            )
+            max_obs = np.full(env.observation_space.shape, max_obs)
+        assert max_obs.shape == env.observation_space.shape
+        assert not np.any(np.isinf(max_obs))
+
+        self.min_obs = min_obs
+        self.max_obs = max_obs
+
+        # Treat the old Box bounds as the x-axis and the new Box bounds as the y-axis: the rescaling is the line y = gradient * x + intercept
+        high_low_diff = np.array(
+            env.observation_space.high, dtype=np.float128
+        ) - np.array(env.observation_space.low, dtype=np.float128)
+        gradient = np.array(
+            (max_obs - min_obs) / high_low_diff, dtype=env.observation_space.dtype
+        )
+
+        intercept = gradient * -env.observation_space.low + min_obs
+
+        gym.utils.RecordConstructorArgs.__init__(self, min_obs=min_obs, max_obs=max_obs)
+        TransformObservation.__init__(
+            self,
+            env=env,
+            func=lambda obs: gradient * obs + intercept,
+            observation_space=spaces.Box(
+                low=min_obs,
+                high=max_obs,
+                shape=env.observation_space.shape,
+                dtype=env.observation_space.dtype,
+            ),
+        )
+
+
+class DtypeObservation(
+    TransformObservation[WrapperObsType, ActType, ObsType],
+    gym.utils.RecordConstructorArgs,
+):
+    """Modifies the dtype of an observation array to a specified dtype.
+
+    Note:
+        This is only compatible with :class:`Box`, :class:`Discrete`, :class:`MultiDiscrete` and :class:`MultiBinary` observation spaces
+
+    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.DtypeObservation`.
+
+    Change logs:
+     * v1.0.0 - Initially added
+    """
+
+    def __init__(self, env: gym.Env[ObsType, ActType], dtype: Any):
+        """Constructor for Dtype observation wrapper.
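+
+        A hedged usage sketch (the resulting dtype assumes ``CartPole-v1``'s
+        default ``float32`` observations, cast by ``dtype(obs)``):
+
+        >>> import numpy as np  # doctest: +SKIP
+        >>> env = DtypeObservation(gym.make("CartPole-v1"), dtype=np.float64)  # doctest: +SKIP
+        >>> env.reset(seed=42)[0].dtype  # doctest: +SKIP
+        dtype('float64')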
+
+        Args:
+            env: The environment to wrap
+            dtype: The new dtype of the observation
+        """
+        assert isinstance(
+            env.observation_space,
+            (spaces.Box, spaces.Discrete, spaces.MultiDiscrete, spaces.MultiBinary),
+        )
+
+        self.dtype = dtype
+        if isinstance(env.observation_space, spaces.Box):
+            new_observation_space = spaces.Box(
+                low=env.observation_space.low,
+                high=env.observation_space.high,
+                shape=env.observation_space.shape,
+                dtype=self.dtype,
+            )
+        elif isinstance(env.observation_space, spaces.Discrete):
+            # Box bounds are inclusive, so the largest Discrete value is start + n - 1
+            new_observation_space = spaces.Box(
+                low=env.observation_space.start,
+                high=env.observation_space.start + env.observation_space.n - 1,
+                shape=(),
+                dtype=self.dtype,
+            )
+        elif isinstance(env.observation_space, spaces.MultiDiscrete):
+            new_observation_space = spaces.MultiDiscrete(
+                env.observation_space.nvec, dtype=dtype
+            )
+        elif isinstance(env.observation_space, spaces.MultiBinary):
+            new_observation_space = spaces.Box(
+                low=0,
+                high=1,
+                shape=env.observation_space.shape,
+                dtype=self.dtype,
+            )
+        else:
+            raise TypeError(
+                "DtypeObservation is only compatible with value / array-based observations."
+            )
+
+        gym.utils.RecordConstructorArgs.__init__(self, dtype=dtype)
+        TransformObservation.__init__(
+            self,
+            env=env,
+            func=lambda obs: dtype(obs),
+            observation_space=new_observation_space,
+        )
+
+
+class RenderObservation(
+    TransformObservation[WrapperObsType, ActType, ObsType],
+    gym.utils.RecordConstructorArgs,
+):
+    """Includes the rendered observations in the environment's observations.
+
+    Notes:
+        This was previously called ``PixelObservationWrapper``.
+
+    No vector version of the wrapper exists.
+
+    Example - Replace the observation with the rendered image:
+        >>> env = gym.make("CartPole-v1", render_mode="rgb_array")
+        >>> env = RenderObservation(env, render_only=True)
+        >>> env.observation_space
+        Box(0, 255, (400, 600, 3), uint8)
+        >>> obs, _ = env.reset(seed=123)
+        >>> image = env.render()
+        >>> np.all(obs == image)
+        True
+        >>> obs, *_ = env.step(env.action_space.sample())
+        >>> image = env.render()
+        >>> np.all(obs == image)
+        True
+
+    Example - Add the rendered image to the original observation as a dictionary item:
+        >>> env = gym.make("CartPole-v1", render_mode="rgb_array")
+        >>> env = RenderObservation(env, render_only=False)
+        >>> env.observation_space
+        Dict('pixels': Box(0, 255, (400, 600, 3), uint8), 'state': Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32))
+        >>> obs, info = env.reset(seed=123)
+        >>> obs.keys()
+        dict_keys(['state', 'pixels'])
+        >>> obs["state"]
+        array([ 0.01823519, -0.0446179 , -0.02796401, -0.03156282], dtype=float32)
+        >>> np.all(obs["pixels"] == env.render())
+        True
+        >>> obs, reward, terminates, truncates, info = env.step(env.action_space.sample())
+        >>> image = env.render()
+        >>> np.all(obs["pixels"] == image)
+        True
+
+    Change logs:
+     * v0.15.0 - Initially added as ``PixelObservationWrapper``
+     * v1.0.0 - Renamed to ``RenderObservation``
+    """
+
+    def __init__(
+        self,
+        env: gym.Env[ObsType, ActType],
+        render_only: bool = True,
+        render_key: str = "pixels",
+        obs_key: str = "state",
+    ):
+        """Constructor of the pixel observation wrapper.
+
+        Args:
+            env: The environment to wrap.
+            render_only (bool): If ``True`` (default), the original observation
+                returned by the wrapped environment is discarded and the
+                observation is only the rendered pixels. If ``False``, the
+                observation is a dictionary containing both the original
+                observation and the rendered pixels.
+            render_key: Optional custom string specifying the pixel key. Defaults to "pixels"
+            obs_key: Optional custom string specifying the obs key. Defaults to "state"
+        """
+        # the keyword names must match the constructor arguments so the wrapper can be reconstructed
+        gym.utils.RecordConstructorArgs.__init__(
+            self,
+            render_only=render_only,
+            render_key=render_key,
+            obs_key=obs_key,
+        )
+
+        assert env.render_mode is not None and env.render_mode != "human"
+        env.reset()
+        pixels = env.render()
+        assert pixels is not None and isinstance(pixels, np.ndarray)
+        pixel_space = spaces.Box(low=0, high=255, shape=pixels.shape, dtype=np.uint8)
+
+        if render_only:
+            obs_space = pixel_space
+            TransformObservation.__init__(
+                self, env=env, func=lambda _: self.render(), observation_space=obs_space
+            )
+        elif isinstance(env.observation_space, spaces.Dict):
+            assert render_key not in env.observation_space.spaces.keys()
+
+            obs_space = spaces.Dict(
+                {render_key: pixel_space, **env.observation_space.spaces}
+            )
+            TransformObservation.__init__(
+                self,
+                env=env,
+                func=lambda obs: {render_key: self.render(), **obs},
+                observation_space=obs_space,
+            )
+        else:
+            obs_space = spaces.Dict(
+                {obs_key: env.observation_space, render_key: pixel_space}
+            )
+            TransformObservation.__init__(
+                self,
+                env=env,
+                func=lambda obs: {obs_key: obs, render_key: self.render()},
+                observation_space=obs_space,
+            )
diff --git a/gymnasium/wrappers/transform_reward.py b/gymnasium/wrappers/transform_reward.py
index 1da56f489..aa7078ecc 100644
--- a/gymnasium/wrappers/transform_reward.py
+++ b/gymnasium/wrappers/transform_reward.py
@@ -1,46 +1,112 @@
-"""Wrapper for transforming the reward."""
-from typing import Callable
+"""A collection of wrappers for modifying the reward.
+
+* ``TransformReward`` - Transforms the reward by a function
+* ``ClipReward`` - Clips the reward between a minimum and maximum value
+"""
+from __future__ import annotations
+
+from typing import Callable, SupportsFloat
+
+import numpy as np
 
 import gymnasium as gym
+from gymnasium.core import ActType, ObsType
+from gymnasium.error import InvalidBound
 
 
-class TransformReward(gym.RewardWrapper, gym.utils.RecordConstructorArgs):
-    """Transform the reward via an arbitrary function.
+__all__ = ["TransformReward", "ClipReward"]
 
-    Warning:
-        If the base environment specifies a reward range which is not invariant under :attr:`f`, the :attr:`reward_range` of the wrapped environment will be incorrect.
+
+class TransformReward(
+    gym.RewardWrapper[ObsType, ActType], gym.utils.RecordConstructorArgs
+):
+    """Applies a function to the ``reward`` received from the environment's ``step``.
+
+    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.TransformReward`.
 
     Example:
         >>> import gymnasium as gym
         >>> from gymnasium.wrappers import TransformReward
         >>> env = gym.make("CartPole-v1")
-        >>> env = TransformReward(env, lambda r: 0.01*r)
+        >>> env = TransformReward(env, lambda r: 2 * r + 1)
         >>> _ = env.reset()
-        >>> observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
-        >>> reward
-        0.01
+        >>> _, rew, _, _, _ = env.step(0)
+        >>> rew
+        3.0
+
+    Change logs:
+     * v0.15.0 - Initially added
     """
 
-    def __init__(self, env: gym.Env, f: Callable[[float], float]):
-        """Initialize the :class:`TransformReward` wrapper with an environment and reward transform function :attr:`f`.
+    def __init__(
+        self,
+        env: gym.Env[ObsType, ActType],
+        func: Callable[[SupportsFloat], SupportsFloat],
+    ):
+        """Initialize TransformReward wrapper.
 
         Args:
-            env: The environment to apply the wrapper
-            f: A function that transforms the reward
+            env (Env): The environment to wrap
+            func (Callable): The function to apply to the reward
         """
-        gym.utils.RecordConstructorArgs.__init__(self, f=f)
+        gym.utils.RecordConstructorArgs.__init__(self, func=func)
         gym.RewardWrapper.__init__(self, env)
 
-        assert callable(f)
-        self.f = f
+        self.func = func
 
-    def reward(self, reward):
-        """Transforms the reward using callable :attr:`f`.
+    def reward(self, reward: SupportsFloat) -> SupportsFloat:
+        """Apply function to reward.
 
         Args:
-            reward: The reward to transform
-
-        Returns:
-            The transformed reward
+            reward (Union[float, int, np.ndarray]): environment's reward
         """
-        return self.f(reward)
+        return self.func(reward)
+
+
+class ClipReward(TransformReward[ObsType, ActType], gym.utils.RecordConstructorArgs):
+    """Clips the rewards for an environment between an upper and lower bound.
+
+    A vector version of the wrapper exists :class:`gymnasium.wrappers.vector.ClipReward`.
+
+    Example:
+        >>> import gymnasium as gym
+        >>> from gymnasium.wrappers import ClipReward
+        >>> env = gym.make("CartPole-v1")
+        >>> env = ClipReward(env, 0, 0.5)
+        >>> _ = env.reset()
+        >>> _, rew, _, _, _ = env.step(1)
+        >>> rew
+        0.5
+
+    Change logs:
+     * v1.0.0 - Initially added
+    """
+
+    def __init__(
+        self,
+        env: gym.Env[ObsType, ActType],
+        min_reward: float | np.ndarray | None = None,
+        max_reward: float | np.ndarray | None = None,
+    ):
+        """Initialize the ClipReward wrapper.
+
+        Args:
+            env (Env): The environment to wrap
+            min_reward (Union[float, np.ndarray]): lower bound to apply
+            max_reward (Union[float, np.ndarray]): upper bound to apply
+        """
+        if min_reward is None and max_reward is None:
+            raise InvalidBound("Both `min_reward` and `max_reward` cannot be None")
+
+        elif max_reward is not None and min_reward is not None:
+            if np.any(max_reward - min_reward < 0):
+                raise InvalidBound(
+                    f"Min reward ({min_reward}) must be smaller than max reward ({max_reward})"
+                )
+
+        gym.utils.RecordConstructorArgs.__init__(
+            self, min_reward=min_reward, max_reward=max_reward
+        )
+        TransformReward.__init__(
+            self, env=env, func=lambda x: np.clip(x, a_min=min_reward, a_max=max_reward)
+        )
diff --git a/gymnasium/experimental/wrappers/utils.py b/gymnasium/wrappers/utils.py
similarity index 96%
rename from gymnasium/experimental/wrappers/utils.py
rename to gymnasium/wrappers/utils.py
index f119668de..fac8e1fde 100644
--- a/gymnasium/experimental/wrappers/utils.py
+++ b/gymnasium/wrappers/utils.py
@@ -28,10 +28,10 @@ class RunningMeanStd:
     """Tracks the mean, variance and count of values."""
 
     # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
-    def __init__(self, epsilon=1e-4, shape=()):
+    def __init__(self, epsilon=1e-4, shape=(), dtype=np.float64):
         """Tracks the mean, variance and count of values."""
-        self.mean = np.zeros(shape, "float64")
-        self.var = np.ones(shape, "float64")
+        self.mean = np.zeros(shape, dtype=dtype)
+        self.var = np.ones(shape, dtype=dtype)
         self.count = epsilon
 
     def update(self, x):
diff --git a/gymnasium/wrappers/vector/__init__.py b/gymnasium/wrappers/vector/__init__.py
new file mode 100644
index 000000000..f832a9f06
--- /dev/null
+++ b/gymnasium/wrappers/vector/__init__.py
@@ -0,0 +1,106 @@
+"""Wrappers for vector environments."""
+# pyright: reportUnsupportedDunderAll=false
+import importlib
+
+from gymnasium.wrappers.vector.common import RecordEpisodeStatistics
+from gymnasium.wrappers.vector.dict_info_to_list import DictInfoToList
+from gymnasium.wrappers.vector.stateful_observation import NormalizeObservation
+from gymnasium.wrappers.vector.stateful_reward import NormalizeReward
+from gymnasium.wrappers.vector.vectorize_action import (
+    ClipAction,
+    RescaleAction,
+    TransformAction,
+    VectorizeTransformAction,
+)
+from gymnasium.wrappers.vector.vectorize_observation import (
+    DtypeObservation,
+    FilterObservation,
+    FlattenObservation,
+    GrayscaleObservation,
+    RescaleObservation,
+    ReshapeObservation,
+    ResizeObservation,
+    TransformObservation,
+    VectorizeTransformObservation,
+)
+from gymnasium.wrappers.vector.vectorize_reward import (
+    ClipReward,
+    TransformReward,
+    VectorizeTransformReward,
+)
+
+
+__all__ = [
+    # --- Vector only wrappers
+    "VectorizeTransformObservation",
+    "VectorizeTransformAction",
+    "VectorizeTransformReward",
+    "DictInfoToList",
+    # --- Observation wrappers ---
+    "TransformObservation",
+    "FilterObservation",
+    "FlattenObservation",
+    "GrayscaleObservation",
+    "ResizeObservation",
+    "ReshapeObservation",
+    "RescaleObservation",
+    "DtypeObservation",
+    "NormalizeObservation",
+    # "RenderObservation",
+    # "TimeAwareObservation",
+    # "FrameStackObservation",
+    # "DelayObservation",
+    # --- Action Wrappers ---
+    "TransformAction",
+    "ClipAction",
+    "RescaleAction",
+    # --- Reward wrappers ---
+    "TransformReward",
+    "ClipReward",
+    "NormalizeReward",
+    # --- Common ---
+    "RecordEpisodeStatistics",
+    # --- Rendering ---
+    # "RenderCollection",
+    # "RecordVideo",
+    # "HumanRendering",
+    # --- Conversion ---
+    "JaxToNumpy",
+    "JaxToTorch",
+    "NumpyToTorch",
+]
+
+
+# As these wrappers require `jax` or `torch`, they are loaded at runtime when users first access them
+# to avoid `import jax` or `import torch` on `import gymnasium`.
+_wrapper_to_class = {
+    # data converters
+    "JaxToNumpy": "jax_to_numpy",
+    "JaxToTorch": "jax_to_torch",
+    "NumpyToTorch": "numpy_to_torch",
+}
+
+
+def __getattr__(wrapper_name: str):
+    """Load a wrapper by name.
+
+    This optimizes the loading of gymnasium wrappers by only loading the wrapper if it is used.
+    Errors will be raised if the wrapper does not exist or if the version is not the latest.
+
+    Args:
+        wrapper_name: The name of a wrapper to load.
+
+    Returns:
+        The specified wrapper.
+
+    Raises:
+        AttributeError: If the wrapper does not exist.
+        DeprecatedWrapper: If the version is not the latest.
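+
+    A usage sketch of the lazy import (assuming ``torch`` is installed; nothing
+    is imported until the attribute is first accessed):
+
+    >>> import gymnasium.wrappers.vector as vector_wrappers  # doctest: +SKIP
+    >>> vector_wrappers.NumpyToTorch  # doctest: +SKIP
+    <class 'gymnasium.wrappers.vector.numpy_to_torch.NumpyToTorch'>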
+    """
+    # Check if the requested wrapper is in the _wrapper_to_class dictionary
+    if wrapper_name in _wrapper_to_class:
+        import_stmt = f"gymnasium.wrappers.vector.{_wrapper_to_class[wrapper_name]}"
+        module = importlib.import_module(import_stmt)
+        return getattr(module, wrapper_name)
+
+    raise AttributeError(f"module {__name__!r} has no attribute {wrapper_name!r}")
diff --git a/gymnasium/experimental/wrappers/vector/record_episode_statistics.py b/gymnasium/wrappers/vector/common.py
similarity index 73%
rename from gymnasium/experimental/wrappers/vector/record_episode_statistics.py
rename to gymnasium/wrappers/vector/common.py
index da5bb5105..0ad5883d6 100644
--- a/gymnasium/experimental/wrappers/vector/record_episode_statistics.py
+++ b/gymnasium/wrappers/vector/common.py
@@ -7,32 +7,18 @@ from collections import deque
 import numpy as np
 
 from gymnasium.core import ActType, ObsType
-from gymnasium.experimental.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
+from gymnasium.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
 
-__all__ = ["RecordEpisodeStatisticsV0"]
+__all__ = ["RecordEpisodeStatistics"]
 
 
-class RecordEpisodeStatisticsV0(VectorWrapper):
+class RecordEpisodeStatistics(VectorWrapper):
     """This wrapper will keep track of cumulative rewards and episode lengths.
 
-    At the end of an episode, the statistics of the episode will be added to ``info``
-    using the key ``episode``. If using a vectorized environment also the key
-    ``_episode`` is used which indicates whether the env at the respective index has
-    the episode statistics.
-
-    After the completion of an episode, ``info`` will look like this::
-
-        >>> info = { # doctest: +SKIP
-        ...     ...
-        ...     "episode": {
-        ...         "r": "<cumulative reward>",
-        ...         "l": "<episode length>",
-        ...         "t": "<elapsed time since instantiation of wrapper>"
-        ...     },
-        ... }
-
-    For a vectorized environments the output will be in the form of::
+    At the end of any episode within the vectorized env, the statistics of the episode
+    will be added to ``info`` using the key ``episode``, and the ``_episode`` key
+    is used to indicate the environment index which has a terminated or truncated episode.
 
         >>> infos = { # doctest: +SKIP
         ...     ...
@@ -50,6 +36,30 @@ class RecordEpisodeStatisticsV0(VectorWrapper):
 
     Attributes:
         return_queue: The cumulative rewards of the last ``deque_size``-many episodes
        length_queue: The lengths of the last ``deque_size``-many episodes
+
+    Example:
+        >>> from pprint import pprint
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("CartPole-v1", num_envs=3)
+        >>> envs = RecordEpisodeStatistics(envs)
+        >>> obs, info = envs.reset(seed=123)
+        >>> _ = envs.action_space.seed(123)
+        >>> end = False
+        >>> while not end:
+        ...     obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
+        ...     end = term.any() or trunc.any()
+        ...
+        >>> envs.close()
+        >>> pprint(info) # doctest: +SKIP
+        {'_episode': array([ True, False, False]),
+         '_final_info': array([ True, False, False]),
+         '_final_observation': array([ True, False, False]),
+         'episode': {'l': array([11,  0,  0], dtype=int32),
+                     'r': array([11.,  0.,  0.], dtype=float32),
+                     't': array([0.007812, 0.      , 0.      ], dtype=float32)},
+         'final_info': array([{}, None, None], dtype=object),
+         'final_observation': array([array([ 0.11448676,  0.9416149 , -0.20946532, -1.7619033 ], dtype=float32),
+                                     None, None], dtype=object)}
     """
 
     def __init__(self, env: VectorEnv, deque_size: int = 100):
diff --git a/gymnasium/wrappers/vector/dict_info_to_list.py b/gymnasium/wrappers/vector/dict_info_to_list.py
new file mode 100644
index 000000000..64b908ca1
--- /dev/null
+++ b/gymnasium/wrappers/vector/dict_info_to_list.py
@@ -0,0 +1,153 @@
+"""Wrapper that converts the info format for vec envs into the list format."""
+from __future__ import annotations
+
+from typing import Any
+
+from gymnasium.core import ActType, ObsType
+from gymnasium.vector.vector_env import ArrayType, VectorEnv, VectorWrapper
+
+
+__all__ = ["DictInfoToList"]
+
+
+class DictInfoToList(VectorWrapper):
+    """Converts infos of vectorized environments from ``dict`` to ``List[dict]``.
+
+    This wrapper converts the info format of a
+    vector environment from a dictionary to a list of dictionaries.
+    This wrapper is intended to be used around vectorized
+    environments. If using other wrappers that perform
+    operations on info, such as `RecordEpisodeStatistics`, this
+    needs to be the outermost wrapper.
+
+    i.e. ``DictInfoToList(RecordEpisodeStatistics(vector_env))``
+
+    Example:
+        >>> import numpy as np
+        >>> dict_info = {
+        ...     "k": np.array([0., 0., 0.5, 0.3]),
+        ...     "_k": np.array([False, False, True, True])
+        ... }
+        ...
+        >>> list_info = [{}, {}, {"k": 0.5}, {"k": 0.3}]
+
+    Example for vector environments:
+        >>> import numpy as np
+        >>> import gymnasium as gym
+        >>> from gymnasium.spaces import Dict, Box
+        >>> envs = gym.make_vec("CartPole-v1", num_envs=3)
+        >>> obs, info = envs.reset(seed=123)
+        >>> info
+        {}
+        >>> envs = DictInfoToList(envs)
+        >>> obs, info = envs.reset(seed=123)
+        >>> info
+        [{}, {}, {}]
+
+    Another example for vector environments:
+        >>> import numpy as np
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("HalfCheetah-v4", num_envs=3)
+        >>> _ = envs.reset(seed=123)
+        >>> _ = envs.action_space.seed(123)
+        >>> _, _, _, _, infos = envs.step(envs.action_space.sample())
+        >>> infos
+        {'x_position': array([0.03332211, 0.10172355, 0.08920531]), '_x_position': array([ True,  True,  True]), 'x_velocity': array([-0.06296527,  0.89345848,  0.37710836]), '_x_velocity': array([ True,  True,  True]), 'reward_run': array([-0.06296527,  0.89345848,  0.37710836]), '_reward_run': array([ True,  True,  True]), 'reward_ctrl': array([-0.24503503, -0.21944423, -0.20672209]), '_reward_ctrl': array([ True,  True,  True])}
+        >>> envs = DictInfoToList(envs)
+        >>> _ = envs.reset(seed=123)
+        >>> _ = envs.action_space.seed(123)
+        >>> _, _, _, _, infos = envs.step(envs.action_space.sample())
+        >>> infos
+        [{'x_position': 0.03332210900362942, 'x_velocity': -0.06296527291998533, 'reward_run': -0.06296527291998533, 'reward_ctrl': -0.2450350284576416}, {'x_position': 0.10172354684460168, 'x_velocity': 0.8934584807363618, 'reward_run': 0.8934584807363618, 'reward_ctrl': -0.21944422721862794}, {'x_position': 0.08920531470057845, 'x_velocity': 0.3771083596080768, 'reward_run': 0.3771083596080768, 'reward_ctrl': -0.20672209262847902}]
+
+    Change logs:
+     * v0.24.0 - Initially added as ``VectorListInfo``
+     * v1.0.0 - Renamed to ``DictInfoToList``
+    """
+
+    def __init__(self, env: VectorEnv):
+        """This wrapper will convert the info into the list format.
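+
+        For reference: the conversion relies on the vector-env convention that
+        every info key ``k`` is paired with a boolean mask key ``_k`` marking
+        which sub-environments actually produced the value.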
+
+        Args:
+            env (Env): The environment to apply the wrapper
+        """
+        super().__init__(env)
+
+    def step(
+        self, actions: ActType
+    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, list[dict[str, Any]]]:
+        """Steps through the environment, converting the dict info to a list."""
+        observation, reward, terminated, truncated, infos = self.env.step(actions)
+        list_info = self._convert_info_to_list(infos)
+
+        return observation, reward, terminated, truncated, list_info
+
+    def reset(
+        self,
+        *,
+        seed: int | list[int] | None = None,
+        options: dict[str, Any] | None = None,
+    ) -> tuple[ObsType, list[dict[str, Any]]]:
+        """Resets the environment using kwargs."""
+        obs, infos = self.env.reset(seed=seed, options=options)
+        list_info = self._convert_info_to_list(infos)
+
+        return obs, list_info
+
+    def _convert_info_to_list(self, infos: dict) -> list[dict[str, Any]]:
+        """Convert the dict info to list.
+
+        Convert the dict info of the vectorized environment
+        into a list of dictionaries where the i-th dictionary
+        has the info of the i-th environment.
+
+        Args:
+            infos (dict): info dict coming from the env.
+
+        Returns:
+            list_info (list): converted info.
+
+        """
+        list_info = [{} for _ in range(self.num_envs)]
+        list_info = self._process_episode_statistics(infos, list_info)
+        for k in infos:
+            if k.startswith("_"):
+                continue
+            for i, has_info in enumerate(infos[f"_{k}"]):
+                if has_info:
+                    list_info[i][k] = infos[k][i]
+        return list_info
+
+    # todo - I think this function should be more general for any information
+    def _process_episode_statistics(self, infos: dict, list_info: list) -> list[dict]:
+        """Process episode statistics.
+
+        The `RecordEpisodeStatistics` wrapper adds extra
+        information to the info. This information is in
+        the form of a dict of dicts. This method processes this
+        information and adds it to the list info.
+        `RecordEpisodeStatistics` info contains the keys
+        "r", "l", "t" which represent "cumulative reward",
+        "episode length" and "elapsed time since instantiation of wrapper".
+
+        Args:
+            infos (dict): infos coming from `RecordEpisodeStatistics`.
+            list_info (list): info of the current vectorized environment.
+
+        Returns:
+            list_info (list): updated info.
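+
+        A sketch of the transformation (values are illustrative only):
+
+        >>> infos = {"episode": {"r": [5.0, 0.0], "l": [5, 0], "t": [0.01, 0.0]},
+        ...          "_episode": [True, False]}  # doctest: +SKIP
+        >>> # list_info[0]["episode"] == {"r": 5.0, "l": 5, "t": 0.01}
+        >>> # list_info[1] is left unchanged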
+ + """ + episode_statistics = infos.pop("episode", False) + if not episode_statistics: + return list_info + + episode_statistics_mask = infos.pop("_episode") + for i, has_info in enumerate(episode_statistics_mask): + if has_info: + list_info[i]["episode"] = {} + list_info[i]["episode"]["r"] = episode_statistics["r"][i] + list_info[i]["episode"]["l"] = episode_statistics["l"][i] + list_info[i]["episode"]["t"] = episode_statistics["t"][i] + + return list_info diff --git a/gymnasium/experimental/wrappers/vector/jax_to_numpy.py b/gymnasium/wrappers/vector/jax_to_numpy.py similarity index 78% rename from gymnasium/experimental/wrappers/vector/jax_to_numpy.py rename to gymnasium/wrappers/vector/jax_to_numpy.py index 512a57f55..6db86fd0f 100644 --- a/gymnasium/experimental/wrappers/vector/jax_to_numpy.py +++ b/gymnasium/wrappers/vector/jax_to_numpy.py @@ -7,21 +7,26 @@ import jax.numpy as jnp from gymnasium.core import ActType, ObsType from gymnasium.error import DependencyNotInstalled -from gymnasium.experimental.vector import VectorEnv, VectorWrapper -from gymnasium.experimental.vector.vector_env import ArrayType -from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax +from gymnasium.vector import VectorEnv, VectorWrapper +from gymnasium.vector.vector_env import ArrayType +from gymnasium.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax -__all__ = ["JaxToNumpyV0"] +__all__ = ["JaxToNumpy"] -class JaxToNumpyV0(VectorWrapper): +class JaxToNumpy(VectorWrapper): """Wraps a jax vector environment so that it can be interacted with through numpy arrays. Notes: - A vectorized version of ``gymnasium.experimental.wrappers.JaxToNumpyV0`` + A vectorized version of ``gymnasium.wrappers.JaxToNumpy`` Actions must be provided as numpy arrays and observations, rewards, terminations and truncations will be returned as numpy arrays. + + Example: + >>> import gymnasium as gym # doctest: +SKIP + >>> envs = gym.make_vec("JaxEnv-vx", 3) # doctest: +SKIP + >>> envs = JaxToNumpy(envs) # doctest: +SKIP """ def __init__(self, env: VectorEnv): diff --git a/gymnasium/experimental/wrappers/vector/jax_to_torch.py b/gymnasium/wrappers/vector/jax_to_torch.py similarity index 80% rename from gymnasium/experimental/wrappers/vector/jax_to_torch.py rename to gymnasium/wrappers/vector/jax_to_torch.py index cca5feaaf..b55157514 100644 --- a/gymnasium/experimental/wrappers/vector/jax_to_torch.py +++ b/gymnasium/wrappers/vector/jax_to_torch.py @@ -4,22 +4,23 @@ from __future__ import annotations from typing import Any from gymnasium.core import ActType, ObsType -from gymnasium.experimental.vector import VectorEnv, VectorWrapper -from gymnasium.experimental.vector.vector_env import ArrayType -from gymnasium.experimental.wrappers.jax_to_torch import ( - Device, - jax_to_torch, - torch_to_jax, -) +from gymnasium.vector import VectorEnv, VectorWrapper +from gymnasium.vector.vector_env import ArrayType +from gymnasium.wrappers.jax_to_torch import Device, jax_to_torch, torch_to_jax -__all__ = ["JaxToTorchV0"] +__all__ = ["JaxToTorch"] -class JaxToTorchV0(VectorWrapper): +class JaxToTorch(VectorWrapper): """Wraps a Jax-based vector environment so that it can be interacted with through PyTorch Tensors. Actions must be provided as PyTorch Tensors and observations, rewards, terminations and truncations will be returned as PyTorch Tensors. 
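+
+    The batched actions are converted with ``torch_to_jax`` and the returned
+    observations, rewards, terminations and truncations with ``jax_to_torch``
+    (the helpers imported above from ``gymnasium.wrappers.jax_to_torch``).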
+
+    Example:
+        >>> import gymnasium as gym  # doctest: +SKIP
+        >>> envs = gym.make_vec("JaxEnv-vx", 3)  # doctest: +SKIP
+        >>> envs = JaxToTorch(envs)  # doctest: +SKIP
     """
 
     def __init__(self, env: VectorEnv, device: Device | None = None):
diff --git a/gymnasium/experimental/wrappers/vector/numpy_to_torch.py b/gymnasium/wrappers/vector/numpy_to_torch.py
similarity index 66%
rename from gymnasium/experimental/wrappers/vector/numpy_to_torch.py
rename to gymnasium/wrappers/vector/numpy_to_torch.py
index 2e717f17f..4a2c3ab75 100644
--- a/gymnasium/experimental/wrappers/vector/numpy_to_torch.py
+++ b/gymnasium/wrappers/vector/numpy_to_torch.py
@@ -4,20 +4,39 @@ from __future__ import annotations
 from typing import Any
 
 from gymnasium.core import ActType, ObsType
-from gymnasium.experimental.vector import VectorEnv, VectorWrapper
-from gymnasium.experimental.vector.vector_env import ArrayType
-from gymnasium.experimental.wrappers.jax_to_torch import Device
-from gymnasium.experimental.wrappers.numpy_to_torch import (
-    numpy_to_torch,
-    torch_to_numpy,
-)
+from gymnasium.vector import VectorEnv, VectorWrapper
+from gymnasium.vector.vector_env import ArrayType
+from gymnasium.wrappers.jax_to_torch import Device
+from gymnasium.wrappers.numpy_to_torch import numpy_to_torch, torch_to_numpy
 
-__all__ = ["NumpyToTorchV0"]
+__all__ = ["NumpyToTorch"]
 
 
-class NumpyToTorchV0(VectorWrapper):
-    """Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors."""
+class NumpyToTorch(VectorWrapper):
+    """Wraps a numpy-based environment so that it can be interacted with through PyTorch Tensors.
+
+    Example:
+        >>> import torch
+        >>> import gymnasium as gym
+        >>> from gymnasium.wrappers.vector import NumpyToTorch
+        >>> envs = gym.make_vec("CartPole-v1", 3)
+        >>> envs = NumpyToTorch(envs)
+        >>> obs, _ = envs.reset(seed=123)
+        >>> type(obs)
+        <class 'torch.Tensor'>
+        >>> action = torch.tensor(envs.action_space.sample())
+        >>> obs, reward, terminated, truncated, info = envs.step(action)
+        >>> envs.close()
+        >>> type(obs)
+        <class 'torch.Tensor'>
+        >>> type(reward)
+        <class 'torch.Tensor'>
+        >>> type(terminated)
+        <class 'torch.Tensor'>
+        >>> type(truncated)
+        <class 'torch.Tensor'>
+    """
 
     def __init__(self, env: VectorEnv, device: Device | None = None):
         """Wrapper class to change inputs and outputs of environment to PyTorch tensors.
diff --git a/gymnasium/wrappers/vector/stateful_observation.py b/gymnasium/wrappers/vector/stateful_observation.py
new file mode 100644
index 000000000..1bfbf18c2
--- /dev/null
+++ b/gymnasium/wrappers/vector/stateful_observation.py
@@ -0,0 +1,111 @@
+"""A collection of stateful observation wrappers.
+
+* ``NormalizeObservation`` - Normalize the observations
+"""
+from __future__ import annotations
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium.core import ObsType
+from gymnasium.vector.vector_env import VectorEnv, VectorObservationWrapper
+from gymnasium.wrappers.utils import RunningMeanStd
+
+
+__all__ = ["NormalizeObservation"]
+
+
+class NormalizeObservation(VectorObservationWrapper, gym.utils.RecordConstructorArgs):
+    """This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
+
+    The property `_update_running_mean` allows to freeze/continue the running mean calculation of the observation
+    statistics. If `True` (default), the `RunningMeanStd` will get updated every step and reset call.
+    If `False`, the calculated statistics are used but not updated anymore; this may be used during evaluation.
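+
+    A sketch of freezing the statistics for evaluation (``update_running_mean``
+    is the public property defined below):
+
+    >>> envs = NormalizeObservation(gym.make_vec("CartPole-v1", num_envs=3))  # doctest: +SKIP
+    >>> envs.update_running_mean = False  # doctest: +SKIP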
+
+    Note:
+        The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was
+        newly instantiated or the policy was changed recently.
+
+    Example without the normalize observation wrapper:
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
+        >>> obs, info = envs.reset(seed=123)
+        >>> _ = envs.action_space.seed(123)
+        >>> for _ in range(100):
+        ...     obs, *_ = envs.step(envs.action_space.sample())
+        >>> np.mean(obs)
+        -0.017698428
+        >>> np.std(obs)
+        0.62041104
+        >>> envs.close()
+
+    Example with the normalize observation wrapper:
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
+        >>> envs = NormalizeObservation(envs)
+        >>> obs, info = envs.reset(seed=123)
+        >>> _ = envs.action_space.seed(123)
+        >>> for _ in range(100):
+        ...     obs, *_ = envs.step(envs.action_space.sample())
+        >>> np.mean(obs)
+        -0.28381696
+        >>> np.std(obs)
+        1.21742
+        >>> envs.close()
+    """
+
+    def __init__(self, env: VectorEnv, epsilon: float = 1e-8):
+        """This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
+
+        Args:
+            env (Env): The environment to apply the wrapper
+            epsilon: A stability parameter that is used when scaling the observations.
+        """
+        gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
+        VectorObservationWrapper.__init__(self, env)
+
+        self.obs_rms = RunningMeanStd(
+            shape=self.single_observation_space.shape,
+            dtype=self.single_observation_space.dtype,
+        )
+        self.epsilon = epsilon
+        self._update_running_mean = True
+
+    @property
+    def update_running_mean(self) -> bool:
+        """Property to freeze/continue the running mean calculation of the observation statistics."""
+        return self._update_running_mean
+
+    @update_running_mean.setter
+    def update_running_mean(self, setting: bool):
+        """Sets the property to freeze/continue the running mean calculation of the observation statistics."""
+        self._update_running_mean = setting
+
+    def vector_observation(self, observation: ObsType) -> ObsType:
+        """Defines the vector observation normalization function.
+
+        Args:
+            observation: A vector observation from the environment
+
+        Returns:
+            the normalized observation
+        """
+        return self._normalize_observations(observation)
+
+    def single_observation(self, observation: ObsType) -> ObsType:
+        """Defines the single observation normalization function.
+
+        Args:
+            observation: A single observation from the environment
+
+        Returns:
+            The normalized observation
+        """
+        return self._normalize_observations(observation[None])
+
+    def _normalize_observations(self, observations: ObsType) -> ObsType:
+        if self._update_running_mean:
+            self.obs_rms.update(observations)
+        return (observations - self.obs_rms.mean) / np.sqrt(
+            self.obs_rms.var + self.epsilon
+        )
diff --git a/gymnasium/wrappers/vector/stateful_reward.py b/gymnasium/wrappers/vector/stateful_reward.py
new file mode 100644
index 000000000..8b96ae3ad
--- /dev/null
+++ b/gymnasium/wrappers/vector/stateful_reward.py
@@ -0,0 +1,115 @@
+"""A collection of wrappers for modifying the reward with an internal state.
+ +* ``NormalizeReward`` - Normalizes the rewards to a mean and standard deviation +""" +from __future__ import annotations + +from typing import Any, SupportsFloat + +import numpy as np + +import gymnasium as gym +from gymnasium.core import ActType, ObsType +from gymnasium.vector.vector_env import ArrayType, VectorEnv, VectorWrapper +from gymnasium.wrappers.utils import RunningMeanStd + + +__all__ = ["NormalizeReward"] + + +class NormalizeReward(VectorWrapper, gym.utils.RecordConstructorArgs): + r"""This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. + + The exponential moving average will have variance :math:`(1 - \gamma)^2`. + + The property `_update_running_mean` allows to freeze/continue the running mean calculation of the reward + statistics. If `True` (default), the `RunningMeanStd` will get updated every time `self.normalize()` is called. + If False, the calculated statistics are used but not updated anymore; this may be used during evaluation. + + Note: + The scaling depends on past trajectories and rewards will not be scaled correctly if the wrapper was newly + instantiated or the policy was changed recently. + + Example without the normalize reward wrapper: + >>> import gymnasium as gym + >>> import numpy as np + >>> envs = gym.make_vec("MountainCarContinuous-v0", 3) + >>> _ = envs.reset(seed=123) + >>> _ = envs.action_space.seed(123) + >>> episode_rewards = [] + >>> for _ in range(100): + ... observation, reward, *_ = envs.step(envs.action_space.sample()) + ... episode_rewards.append(reward) + ... + >>> envs.close() + >>> np.mean(episode_rewards) + -0.03359492141887935 + >>> np.std(episode_rewards) + 0.029028230434438706 + + Example with the normalize reward wrapper: + >>> import gymnasium as gym + >>> import numpy as np + >>> envs = gym.make_vec("MountainCarContinuous-v0", 3) + >>> envs = NormalizeReward(envs) + >>> _ = envs.reset(seed=123) + >>> _ = envs.action_space.seed(123) + >>> episode_rewards = [] + >>> for _ in range(100): + ... observation, reward, *_ = envs.step(envs.action_space.sample()) + ... episode_rewards.append(reward) + ... + >>> envs.close() + >>> np.mean(episode_rewards) + -0.1598639586606745 + >>> np.std(episode_rewards) + 0.27800309628058434 + """ + + def __init__( + self, + env: VectorEnv, + gamma: float = 0.99, + epsilon: float = 1e-8, + ): + """This wrapper will normalize immediate rewards s.t. their exponential moving average has a fixed variance. + + Args: + env (env): The environment to apply the wrapper + epsilon (float): A stability parameter + gamma (float): The discount factor that is used in the exponential moving average. 
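+
+        For reference, the update applied at each step (see :meth:`step` and
+        :meth:`normalize` below): the accumulated discounted return is
+        ``G_t = gamma * (1 - terminated) * G_{t-1} + r_t``, and the returned
+        reward is ``r_t / sqrt(Var[G] + epsilon)``.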
+        """
+        gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
+        VectorWrapper.__init__(self, env)
+
+        self.return_rms = RunningMeanStd(shape=())
+        self.accumulated_reward: np.ndarray = np.zeros((self.num_envs,), dtype=np.float32)
+        self.gamma = gamma
+        self.epsilon = epsilon
+        self._update_running_mean = True
+
+    @property
+    def update_running_mean(self) -> bool:
+        """Property to freeze/continue the running mean calculation of the reward statistics."""
+        return self._update_running_mean
+
+    @update_running_mean.setter
+    def update_running_mean(self, setting: bool):
+        """Sets the property to freeze/continue the running mean calculation of the reward statistics."""
+        self._update_running_mean = setting
+
+    def step(
+        self, actions: ActType
+    ) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict[str, Any]]:
+        """Steps through the environment, normalizing the reward returned."""
+        obs, reward, terminated, truncated, info = super().step(actions)
+        self.accumulated_reward = (
+            self.accumulated_reward * self.gamma * (1 - terminated) + reward
+        )
+        return obs, self.normalize(reward), terminated, truncated, info
+
+    def normalize(self, reward: SupportsFloat):
+        """Normalizes the rewards with the running mean rewards and their variance."""
+        if self._update_running_mean:
+            self.return_rms.update(self.accumulated_reward)
+        return reward / np.sqrt(self.return_rms.var + self.epsilon)
diff --git a/gymnasium/wrappers/vector/vectorize_action.py b/gymnasium/wrappers/vector/vectorize_action.py
new file mode 100644
index 000000000..63fefb5d5
--- /dev/null
+++ b/gymnasium/wrappers/vector/vectorize_action.py
@@ -0,0 +1,253 @@
+"""Vectorizes action wrappers to work for `VectorEnv`."""
+from __future__ import annotations
+
+from copy import deepcopy
+from typing import Any, Callable
+
+import numpy as np
+
+from gymnasium import Space
+from gymnasium.core import ActType, Env
+from gymnasium.vector import VectorActionWrapper, VectorEnv
+from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
+from gymnasium.wrappers import transform_action
+
+
+class TransformAction(VectorActionWrapper):
+    """Transforms an action via a function provided to the wrapper.
+
+    The function :attr:`func` will be applied to all vector actions.
+    If the actions from :attr:`func` are outside the bounds of the ``env``'s action space,
+    provide an :attr:`action_space` which specifies the action space for the vectorized environment.
+
+    Example - Without action transformation:
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
+        >>> _ = envs.action_space.seed(123)
+        >>> obs, info = envs.reset(seed=123)
+        >>> for _ in range(10):
+        ...     obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
+        ...
+        >>> envs.close()
+        >>> obs
+        array([[-0.46553135, -0.00142543],
+               [-0.498371  , -0.00715587],
+               [-0.4651575 , -0.00624371]], dtype=float32)
+
+    Example - With action transformation:
+        >>> import gymnasium as gym
+        >>> from gymnasium.spaces import Box
+        >>> def shrink_action(act):
+        ...     return act * 0.3
+        ...
+        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
+        >>> new_action_space = Box(low=shrink_action(envs.action_space.low), high=shrink_action(envs.action_space.high))
+        >>> envs = TransformAction(env=envs, func=shrink_action, action_space=new_action_space)
+        >>> _ = envs.action_space.seed(123)
+        >>> obs, info = envs.reset(seed=123)
+        >>> for _ in range(10):
+        ...     obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
+        ...
+        >>> envs.close()
+        >>> obs
+        array([[-0.48468155, -0.00372536],
+               [-0.47599354, -0.00545912],
+               [-0.46543318, -0.00615723]], dtype=float32)
+    """
+
+    def __init__(
+        self,
+        env: VectorEnv,
+        func: Callable[[ActType], Any],
+        action_space: Space | None = None,
+    ):
+        """Constructor for the transform action wrapper.
+
+        Args:
+            env: The vector environment to wrap
+            func: A function that will transform an action. If this transformed action is outside the action space of ``env.action_space`` then provide an ``action_space``.
+            action_space: The action spaces of the wrapper, if None, then it is assumed the same as ``env.action_space``.
+        """
+        super().__init__(env)
+
+        if action_space is not None:
+            self.action_space = action_space
+
+        self.func = func
+
+    def actions(self, actions: ActType) -> ActType:
+        """Applies the :attr:`func` to the actions."""
+        return self.func(actions)
+
+
+class VectorizeTransformAction(VectorActionWrapper):
+    """Vectorizes a single-agent transform action wrapper for vector environments.
+
+    Example - Without action transformation:
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
+        >>> _ = envs.action_space.seed(123)
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
+        >>> envs.close()
+        >>> obs
+        array([[-4.6343064e-01,  9.8971417e-05],
+               [-4.4488689e-01, -1.9375233e-03],
+               [-4.3118435e-01, -1.5342437e-03]], dtype=float32)
+
+    Example - Adding a transform that applies a ReLU to the action:
+        >>> import gymnasium as gym
+        >>> from gymnasium.wrappers import TransformAction
+        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
+        >>> envs = VectorizeTransformAction(envs, wrapper=TransformAction, func=lambda x: (x > 0.0) * x, action_space=envs.single_action_space)
+        >>> _ = envs.action_space.seed(123)
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
+        >>> envs.close()
+        >>> obs
+        array([[-4.6343064e-01,  9.8971417e-05],
+               [-4.4354835e-01, -5.9898634e-04],
+               [-4.3034542e-01, -6.9532328e-04]], dtype=float32)
+    """
+
+    class _SingleEnv(Env):
+        """Fake single-agent environment used for the single-agent wrapper."""
+
+        def __init__(self, action_space: Space):
+            """Constructor for the fake environment."""
+            self.action_space = action_space
+
+    def __init__(
+        self,
+        env: VectorEnv,
+        wrapper: type[transform_action.TransformAction],
+        **kwargs: Any,
+    ):
+        """Constructor for the vectorized transform action wrapper.
+
+        Args:
+            env: The vector environment to wrap
+            wrapper: The wrapper to vectorize
+            **kwargs: Arguments for the wrapper being vectorized
+        """
+        super().__init__(env)
+
+        self.wrapper = wrapper(self._SingleEnv(self.env.single_action_space), **kwargs)
+        self.single_action_space = self.wrapper.action_space
+        self.action_space = batch_space(self.single_action_space, self.num_envs)
+
+        self.same_out = self.action_space == self.env.action_space
+        self.out = create_empty_array(self.single_action_space, self.num_envs)
+
+    def actions(self, actions: ActType) -> ActType:
+        """Applies the wrapper to each of the actions.
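+
+        The batched actions are split with ``iterate``, transformed one at a
+        time by the single-agent wrapper's ``func``, then re-batched with
+        ``concatenate`` (both helpers from ``gymnasium.vector.utils``).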
+ + Args: + actions: The actions to apply the function to + + Returns: + The updated actions using the wrapper func + """ + if self.same_out: + return concatenate( + self.single_action_space, + tuple( + self.wrapper.func(action) + for action in iterate(self.action_space, actions) + ), + actions, + ) + else: + return deepcopy( + concatenate( + self.single_action_space, + tuple( + self.wrapper.func(action) + for action in iterate(self.env.action_space, actions) + ), + self.out, + ) + ) + + +class ClipAction(VectorizeTransformAction): + """Clip the continuous action within the valid :class:`Box` observation space bound. + + Example - Passing an out-of-bounds action to the environment to be clipped. + >>> import numpy as np + >>> import gymnasium as gym + >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3) + >>> envs = ClipAction(envs) + >>> _ = envs.action_space.seed(123) + >>> obs, info = envs.reset(seed=123) + >>> obs, rew, term, trunc, info = envs.step(np.array([5.0, -5.0, 2.0])) + >>> envs.close() + >>> obs + array([[-0.4624777 , 0.00105192], + [-0.44504836, -0.00209899], + [-0.42884544, 0.00080468]], dtype=float32) + """ + + def __init__(self, env: VectorEnv): + """Constructor for the Clip Action wrapper. + + Args: + env: The vector environment to wrap + """ + super().__init__(env, transform_action.ClipAction) + + +class RescaleAction(VectorizeTransformAction): + """Affinely rescales the continuous action space of the environment to the range [min_action, max_action]. + + Example - Without action scaling: + >>> import numpy as np + >>> import gymnasium as gym + >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3) + >>> _ = envs.action_space.seed(123) + >>> obs, info = envs.reset(seed=123) + >>> for _ in range(10): + ... obs, rew, term, trunc, info = envs.step(0.5 * np.ones((3, 1))) + ... + >>> envs.close() + >>> obs + array([[-0.44799727, 0.00266526], + [-0.4351738 , 0.00133522], + [-0.42683297, 0.00048403]], dtype=float32) + + Example - With action scaling: + >>> import numpy as np + >>> import gymnasium as gym + >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3) + >>> envs = RescaleAction(envs, 0.0, 1.0) + >>> _ = envs.action_space.seed(123) + >>> obs, info = envs.reset(seed=123) + >>> for _ in range(10): + ... obs, rew, term, trunc, info = envs.step(0.5 * np.ones((3, 1))) + ... + >>> envs.close() + >>> obs + array([[-0.48657528, -0.00395268], + [-0.47377947, -0.00529102], + [-0.46546045, -0.00614867]], dtype=float32) + """ + + def __init__( + self, + env: VectorEnv, + min_action: float | int | np.ndarray, + max_action: float | int | np.ndarray, + ): + """Initializes the :class:`RescaleAction` wrapper. + + Args: + env (Env): The vector environment to wrap + min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar. + max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar. 
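+
+        For reference, this delegates to the single-agent
+        :class:`gymnasium.wrappers.RescaleAction`, which affinely maps an action
+        ``a`` in ``[min_action, max_action]`` back onto the environment's
+        original ``[low, high]`` action bounds.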
+        """
+        super().__init__(
+            env,
+            transform_action.RescaleAction,
+            min_action=min_action,
+            max_action=max_action,
+        )
diff --git a/gymnasium/wrappers/vector/vectorize_observation.py b/gymnasium/wrappers/vector/vectorize_observation.py
new file mode 100644
index 000000000..94cacbbe3
--- /dev/null
+++ b/gymnasium/wrappers/vector/vectorize_observation.py
@@ -0,0 +1,394 @@
+"""Vectorizes observation wrappers to work for `VectorEnv`."""
+from __future__ import annotations
+
+from copy import deepcopy
+from typing import Any, Callable, Sequence
+
+import numpy as np
+
+from gymnasium import Space
+from gymnasium.core import Env, ObsType
+from gymnasium.vector import VectorEnv, VectorObservationWrapper
+from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
+from gymnasium.wrappers import transform_observation
+
+
+class TransformObservation(VectorObservationWrapper):
+    """Transforms an observation via a function provided to the wrapper.
+
+    This function allows the manual specification of the vector-observation function as well as the single-observation function.
+    This is desirable when, for example, it is possible to process vector observations in parallel or via other more optimized methods.
+    Otherwise, the ``VectorizeTransformObservation`` should be used instead, where only ``single_func`` needs to be defined.
+
+    Example - Without observation transformation:
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs
+        array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282],
+               [ 0.02852531,  0.02858594,  0.0469136 ,  0.02480598],
+               [ 0.03517495, -0.000635  , -0.01098382, -0.03203924]],
+              dtype=float32)
+        >>> envs.close()
+
+    Example - With observation transformation:
+        >>> import gymnasium as gym
+        >>> from gymnasium.spaces import Box
+        >>> def scale_and_shift(obs):
+        ...     return (obs - 1.0) * 2.0
+        ...
+        >>> def vector_scale_and_shift(obs):
+        ...     return (obs - 1.0) * 2.0
+        ...
+        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
+        >>> new_obs_space = Box(low=envs.observation_space.low, high=envs.observation_space.high)
+        >>> envs = TransformObservation(envs, single_func=scale_and_shift, vector_func=vector_scale_and_shift)
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs
+        array([[-1.9635296, -2.0892358, -2.055928 , -2.0631256],
+               [-1.9429494, -1.9428282, -1.9061728, -1.9503881],
+               [-1.9296501, -2.00127  , -2.0219676, -2.0640786]], dtype=float32)
+        >>> envs.close()
+    """
+
+    def __init__(
+        self,
+        env: VectorEnv,
+        vector_func: Callable[[ObsType], Any],
+        single_func: Callable[[ObsType], Any],
+        observation_space: Space | None = None,
+    ):
+        """Constructor for the transform observation wrapper.
+
+        Args:
+            env: The vector environment to wrap
+            vector_func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
+            single_func: A function that will transform an individual observation; this function is used for the final observation from the environment and is returned under ``info`` rather than the normal observation.
+            observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
+        """
+        super().__init__(env)
+
+        if observation_space is not None:
+            self.observation_space = observation_space
+
+        self.vector_func = vector_func
+        self.single_func = single_func
+
+    def vector_observation(self, observation: ObsType) -> ObsType:
+        """Apply function to the vector observation."""
+        return self.vector_func(observation)
+
+    def single_observation(self, observation: ObsType) -> ObsType:
+        """Apply function to the single observation."""
+        return self.single_func(observation)
+
+
+class VectorizeTransformObservation(VectorObservationWrapper):
+    """Vectorizes a single-agent transform observation wrapper for vector environments.
+
+    Most of the lambda observation wrappers for single-agent environments have vectorized implementations;
+    it is advised that users simply use those instead, importing from `gymnasium.wrappers.vector...`.
+    The following example illustrates a use-case where a custom lambda observation wrapper is required.
+
+    Example - The normal observation:
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
+        >>> obs, info = envs.reset(seed=123)
+        >>> envs.close()
+        >>> obs
+        array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282],
+               [ 0.02852531,  0.02858594,  0.0469136 ,  0.02480598],
+               [ 0.03517495, -0.000635  , -0.01098382, -0.03203924]],
+              dtype=float32)
+
+    Example - Applying a custom lambda observation wrapper that duplicates the observation from the environment:
+        >>> import numpy as np
+        >>> import gymnasium as gym
+        >>> from gymnasium.spaces import Box
+        >>> from gymnasium.wrappers import TransformObservation
+        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
+        >>> old_space = envs.single_observation_space
+        >>> new_space = Box(low=np.array([old_space.low, old_space.low]), high=np.array([old_space.high, old_space.high]))
+        >>> envs = VectorizeTransformObservation(envs, wrapper=TransformObservation, func=lambda x: np.array([x, x]), observation_space=new_space)
+        >>> obs, info = envs.reset(seed=123)
+        >>> envs.close()
+        >>> obs
+        array([[[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282],
+                [ 0.01823519, -0.0446179 , -0.02796401, -0.03156282]],
+
+               [[ 0.02852531,  0.02858594,  0.0469136 ,  0.02480598],
+                [ 0.02852531,  0.02858594,  0.0469136 ,  0.02480598]],
+
+               [[ 0.03517495, -0.000635  , -0.01098382, -0.03203924],
+                [ 0.03517495, -0.000635  , -0.01098382, -0.03203924]]],
+              dtype=float32)
+    """
+
+    class _SingleEnv(Env):
+        """Fake single-agent environment used for the single-agent wrapper."""
+
+        def __init__(self, observation_space: Space):
+            """Constructor for the fake environment."""
+            self.observation_space = observation_space
+
+    def __init__(
+        self,
+        env: VectorEnv,
+        wrapper: type[transform_observation.TransformObservation],
+        **kwargs: Any,
+    ):
+        """Constructor for the vectorized transform observation wrapper.
+
+        Args:
+            env: The vector environment to wrap.
+            wrapper: The wrapper to vectorize
+            **kwargs: Keyword argument for the wrapper
+        """
+        super().__init__(env)
+
+        self.wrapper = wrapper(
+            self._SingleEnv(self.env.single_observation_space), **kwargs
+        )
+        self.single_observation_space = self.wrapper.observation_space
+        self.observation_space = batch_space(
+            self.single_observation_space, self.num_envs
+        )
+
+        self.same_out = self.observation_space == self.env.observation_space
+        self.out = create_empty_array(self.single_observation_space, self.num_envs)
+
+    def vector_observation(self, observation: ObsType) -> ObsType:
+        """Iterates over the vector observations, applying the single-agent wrapper's ``func``, then concatenates the observations together again."""
+        if self.same_out:
+            return concatenate(
+                self.single_observation_space,
+                tuple(
+                    self.wrapper.func(obs)
+                    for obs in iterate(self.observation_space, observation)
+                ),
+                observation,
+            )
+        else:
+            return deepcopy(
+                concatenate(
+                    self.single_observation_space,
+                    tuple(
+                        self.wrapper.func(obs)
+                        for obs in iterate(self.env.observation_space, observation)
+                    ),
+                    self.out,
+                )
+            )
+
+    def single_observation(self, observation: ObsType) -> ObsType:
+        """Transforms a single observation using the wrapper transformation function."""
+        return self.wrapper.func(observation)
+
+
+class FilterObservation(VectorizeTransformObservation):
+    """Vector wrapper for filtering dict or tuple observation spaces.
+
+    Example - Create a vectorized environment with a Dict space to demonstrate how to filter keys:
+        >>> import numpy as np
+        >>> import gymnasium as gym
+        >>> from gymnasium.spaces import Dict, Box
+        >>> from gymnasium.wrappers import TransformObservation
+        >>> from gymnasium.wrappers.vector import VectorizeTransformObservation, FilterObservation
+        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
+        >>> make_dict = lambda x: {"obs": x, "junk": np.array([0.0])}
+        >>> new_space = Dict({"obs": envs.single_observation_space, "junk": Box(low=-1.0, high=1.0)})
+        >>> envs = VectorizeTransformObservation(env=envs, wrapper=TransformObservation, func=make_dict, observation_space=new_space)
+        >>> envs = FilterObservation(envs, ["obs"])
+        >>> obs, info = envs.reset(seed=123)
+        >>> envs.close()
+        >>> obs
+        OrderedDict([('obs', array([[ 0.01823519, -0.0446179 , -0.02796401, -0.03156282],
+               [ 0.02852531,  0.02858594,  0.0469136 ,  0.02480598],
+               [ 0.03517495, -0.000635  , -0.01098382, -0.03203924]],
+              dtype=float32))])
+    """
+
+    def __init__(self, env: VectorEnv, filter_keys: Sequence[str | int]):
+        """Constructor for the filter observation wrapper.
+
+        Args:
+            env: The vector environment to wrap
+            filter_keys: The subspaces to be included, use a list of strings or integers for ``Dict`` and ``Tuple`` spaces respectively
+        """
+        super().__init__(
+            env, transform_observation.FilterObservation, filter_keys=filter_keys
+        )
+
+
+class FlattenObservation(VectorizeTransformObservation):
+    """Observation wrapper that flattens the observation.
+
+    Example:
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("CarRacing-v2", num_envs=3, vectorization_mode="sync")
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs.shape
+        (3, 96, 96, 3)
+        >>> envs = FlattenObservation(envs)
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs.shape
+        (3, 27648)
+        >>> envs.close()
+    """
+
+    def __init__(self, env: VectorEnv):
+        """Constructor for any environment's observation space that implements ``spaces.utils.flatten_space`` and ``spaces.utils.flatten``.
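+
+        A sketch of the shape arithmetic for the example above: each
+        ``Box(0, 255, (96, 96, 3))`` observation flattens to shape ``(27648,)``
+        per sub-environment, since ``96 * 96 * 3 == 27648``.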
+
+        Args:
+            env: The vector environment to wrap
+        """
+        super().__init__(env, transform_observation.FlattenObservation)
+
+
+class GrayscaleObservation(VectorizeTransformObservation):
+    """Observation wrapper that converts an RGB image to grayscale.
+
+    Example:
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("CarRacing-v2", num_envs=3, vectorization_mode="sync")
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs.shape
+        (3, 96, 96, 3)
+        >>> envs = GrayscaleObservation(envs)
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs.shape
+        (3, 96, 96)
+        >>> envs.close()
+    """
+
+    def __init__(self, env: VectorEnv, keep_dim: bool = False):
+        """Constructor for an RGB image based environments to make the image grayscale.
+
+        Args:
+            env: The vector environment to wrap
+            keep_dim: Whether to keep the channel dimension in the observation; if ``True``, ``len(obs.shape) == 3``, else ``len(obs.shape) == 2``
+        """
+        super().__init__(
+            env, transform_observation.GrayscaleObservation, keep_dim=keep_dim
+        )
+
+
+class ResizeObservation(VectorizeTransformObservation):
+    """Resizes image observations using OpenCV to a specified shape.
+
+    Example:
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("CarRacing-v2", num_envs=3, vectorization_mode="sync")
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs.shape
+        (3, 96, 96, 3)
+        >>> envs = ResizeObservation(envs, shape=(28, 28))
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs.shape
+        (3, 28, 28, 3)
+        >>> envs.close()
+    """
+
+    def __init__(self, env: VectorEnv, shape: tuple[int, ...]):
+        """Constructor that requires an image environment observation space with a shape.
+
+        Args:
+            env: The vector environment to wrap
+            shape: The resized observation shape
+        """
+        super().__init__(env, transform_observation.ResizeObservation, shape=shape)
+
+
+class ReshapeObservation(VectorizeTransformObservation):
+    """Reshapes array-based observations to a specified shape.
+
+    Example:
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("CarRacing-v2", num_envs=3, vectorization_mode="sync")
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs.shape
+        (3, 96, 96, 3)
+        >>> envs = ReshapeObservation(envs, shape=(9216, 3))
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs.shape
+        (3, 9216, 3)
+        >>> envs.close()
+    """
+
+    def __init__(self, env: VectorEnv, shape: int | tuple[int, ...]):
+        """Constructor for env with Box observation space that has a shape product equal to the new shape product.
+
+        Args:
+            env: The vector environment to wrap
+            shape: The new shape of the observation
+        """
+        super().__init__(env, transform_observation.ReshapeObservation, shape=shape)
+
+
+class RescaleObservation(VectorizeTransformObservation):
+    """Linearly rescales observations to lie between a minimum and maximum value.
+
+    Example:
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs.min()
+        -0.0446179
+        >>> obs.max()
+        0.0469136
+        >>> envs = RescaleObservation(envs, min_obs=-5.0, max_obs=5.0)
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs.min()
+        -0.33379582
+        >>> obs.max()
+        0.55998987
+        >>> envs.close()
+    """
+
+    def __init__(
+        self,
+        env: VectorEnv,
+        min_obs: np.floating | np.integer | np.ndarray,
+        max_obs: np.floating | np.integer | np.ndarray,
+    ):
+        """Constructor that requires the env observation spaces to be a :class:`Box`.
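+
+        For reference, the underlying single-agent wrapper applies the affine
+        map ``obs -> gradient * obs + intercept`` with
+        ``gradient = (max_obs - min_obs) / (high - low)`` and
+        ``intercept = -gradient * low + min_obs``.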
+
+        Args:
+            env: The vector environment to wrap
+            min_obs: The new minimum observation bound
+            max_obs: The new maximum observation bound
+        """
+        super().__init__(
+            env,
+            transform_observation.RescaleObservation,
+            min_obs=min_obs,
+            max_obs=max_obs,
+        )
+
+
+class DtypeObservation(VectorizeTransformObservation):
+    """Observation wrapper for transforming the dtype of an observation.
+
+    Example:
+        >>> import numpy as np
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs.dtype
+        dtype('float32')
+        >>> envs = DtypeObservation(envs, dtype=np.float64)
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs.dtype
+        dtype('float64')
+        >>> envs.close()
+    """
+
+    def __init__(self, env: VectorEnv, dtype: Any):
+        """Constructor for Dtype observation wrapper.
+
+        Args:
+            env: The vector environment to wrap
+            dtype: The new dtype of the observation
+        """
+        super().__init__(env, transform_observation.DtypeObservation, dtype=dtype)
diff --git a/gymnasium/wrappers/vector/vectorize_reward.py b/gymnasium/wrappers/vector/vectorize_reward.py
new file mode 100644
index 000000000..b535d175a
--- /dev/null
+++ b/gymnasium/wrappers/vector/vectorize_reward.py
@@ -0,0 +1,163 @@
+"""Vectorizes reward functions to work with `VectorEnv`."""
+from __future__ import annotations
+
+from typing import Any, Callable
+
+import numpy as np
+
+from gymnasium import Env
+from gymnasium.vector import VectorEnv, VectorRewardWrapper
+from gymnasium.vector.vector_env import ArrayType
+from gymnasium.wrappers import transform_reward
+
+
+class TransformReward(VectorRewardWrapper):
+    """A reward wrapper that allows a custom function to modify the step reward.
+
+    Example:
+        Without reward transformation:
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
+        >>> _ = envs.action_space.seed(123)
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
+        >>> envs.close()
+        >>> rew
+        array([-0.01330088, -0.07963027, -0.03127944])
+
+        With reward transformation:
+        >>> import gymnasium as gym
+        >>> def scale_and_shift(rew):
+        ...     return (rew - 1.0) * 2.0
+        ...
+        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
+        >>> envs = TransformReward(env=envs, func=scale_and_shift)
+        >>> _ = envs.action_space.seed(123)
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
+        >>> envs.close()
+        >>> rew
+        array([-2.02660176, -2.15926054, -2.06255888])
+    """
+
+    def __init__(self, env: VectorEnv, func: Callable[[ArrayType], ArrayType]):
+        """Initialize TransformReward wrapper.
+
+        Args:
+            env (Env): The vector environment to wrap
+            func (Callable): The function to apply to the reward
+        """
+        super().__init__(env)
+
+        self.func = func
+
+    def rewards(self, reward: ArrayType) -> ArrayType:
+        """Apply function to reward."""
+        return self.func(reward)
+
+
+class VectorizeTransformReward(VectorRewardWrapper):
+    """Vectorizes a single-agent transform reward wrapper for vector environments.
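+
+    The single-agent wrapper is constructed once around a placeholder ``Env()``, since only its
+    reward transformation function is needed; :meth:`rewards` then applies that function to each
+    sub-environment's reward in turn.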
+
+    Example:
+        Without reward transformation:
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
+        >>> _ = envs.action_space.seed(123)
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
+        >>> envs.close()
+        >>> rew
+        array([-0.01330088, -0.07963027, -0.03127944])
+
+        Adding a transform that applies a ReLU to the reward:
+        >>> import gymnasium as gym
+        >>> from gymnasium.wrappers import TransformReward
+        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
+        >>> envs = VectorizeTransformReward(envs, wrapper=TransformReward, func=lambda x: (x > 0.0) * x)
+        >>> _ = envs.action_space.seed(123)
+        >>> obs, info = envs.reset(seed=123)
+        >>> obs, rew, term, trunc, info = envs.step(envs.action_space.sample())
+        >>> envs.close()
+        >>> rew
+        array([-0., -0., -0.])
+    """
+
+    def __init__(
+        self,
+        env: VectorEnv,
+        wrapper: type[transform_reward.TransformReward],
+        **kwargs: Any,
+    ):
+        """Constructor for the vectorized transform reward wrapper.
+
+        Args:
+            env: The vector environment to wrap.
+            wrapper: The wrapper to vectorize
+            **kwargs: Keyword arguments for the wrapper
+        """
+        super().__init__(env)
+
+        # The single-agent wrapper is given a placeholder ``Env()`` as only its reward
+        # transformation function, ``func``, is used.
+        self.wrapper = wrapper(Env(), **kwargs)
+
+    def rewards(self, reward: ArrayType) -> ArrayType:
+        """Iterates over the rewards, updating each with the wrapper's ``func``."""
+        for i, r in enumerate(reward):
+            reward[i] = self.wrapper.func(r)
+        return reward
+
+
+class ClipReward(VectorizeTransformReward):
+    """A wrapper that clips the rewards for an environment between an upper and lower bound.
+
+    Example:
+        Without clipping rewards:
+        >>> import numpy as np
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
+        >>> _ = envs.action_space.seed(123)
+        >>> obs, info = envs.reset(seed=123)
+        >>> for _ in range(10):
+        ...     obs, rew, term, trunc, info = envs.step(0.5 * np.ones((3, 1)))
+        ...
+        >>> envs.close()
+        >>> rew
+        array([-0.025, -0.025, -0.025])
+
+        With clipped rewards:
+        >>> import numpy as np
+        >>> import gymnasium as gym
+        >>> envs = gym.make_vec("MountainCarContinuous-v0", num_envs=3)
+        >>> envs = ClipReward(envs, 0.0, 2.0)
+        >>> _ = envs.action_space.seed(123)
+        >>> obs, info = envs.reset(seed=123)
+        >>> for _ in range(10):
+        ...     obs, rew, term, trunc, info = envs.step(0.5 * np.ones((3, 1)))
+        ...
+        >>> envs.close()
+        >>> rew
+        array([0., 0., 0.])
+    """
+
+    def __init__(
+        self,
+        env: VectorEnv,
+        min_reward: float | np.ndarray | None = None,
+        max_reward: float | np.ndarray | None = None,
+    ):
+        """Constructor for ClipReward wrapper.
+
+        Args:
+            env: The vector environment to wrap
+            min_reward: The min reward for each step
+            max_reward: The max reward for each step
+        """
+        super().__init__(
+            env,
+            transform_reward.ClipReward,
+            min_reward=min_reward,
+            max_reward=max_reward,
+        )
diff --git a/gymnasium/wrappers/vector_list_info.py b/gymnasium/wrappers/vector_list_info.py
deleted file mode 100644
index 3cfc46104..000000000
--- a/gymnasium/wrappers/vector_list_info.py
+++ /dev/null
@@ -1,126 +0,0 @@
-"""Wrapper that converts the info format for vec envs into the list format."""
-
-from typing import List
-
-import gymnasium as gym
-
-
-class VectorListInfo(gym.Wrapper, gym.utils.RecordConstructorArgs):
-    """Converts infos of vectorized environments from dict to List[dict].
-
-    This wrapper converts the info format of a
-    vector environment from a dictionary to a list of dictionaries.
- This wrapper is intended to be used around vectorized - environments. If using other wrappers that perform - operation on info like `RecordEpisodeStatistics` this - need to be the outermost wrapper. - - i.e. `VectorListInfo(RecordEpisodeStatistics(envs))` - - Example: - >>> # As dict: - >>> infos = { - ... "final_observation": "", - ... "_final_observation": "", - ... "final_info": "", - ... "_final_info": "", - ... "episode": { - ... "r": "", - ... "l": "", - ... "t": "" - ... }, - ... "_episode": "" - ... } - >>> # As list: - >>> infos = [ - ... { - ... "episode": {"r": "", "l": "", "t": ""}, - ... "final_observation": "", - ... "final_info": {}, - ... }, - ... ..., - ... ] - """ - - def __init__(self, env): - """This wrapper will convert the info into the list format. - - Args: - env (Env): The environment to apply the wrapper - """ - gym.utils.RecordConstructorArgs.__init__(self) - gym.Wrapper.__init__(self, env) - try: - self.get_wrapper_attr("is_vector_env") - except AttributeError: - assert False, "This wrapper can only be used in vectorized environments." - - def step(self, action): - """Steps through the environment, convert dict info to list.""" - observation, reward, terminated, truncated, infos = self.env.step(action) - list_info = self._convert_info_to_list(infos) - - return observation, reward, terminated, truncated, list_info - - def reset(self, **kwargs): - """Resets the environment using kwargs.""" - obs, infos = self.env.reset(**kwargs) - list_info = self._convert_info_to_list(infos) - return obs, list_info - - def _convert_info_to_list(self, infos: dict) -> List[dict]: - """Convert the dict info to list. - - Convert the dict info of the vectorized environment - into a list of dictionaries where the i-th dictionary - has the info of the i-th environment. - - Args: - infos (dict): info dict coming from the env. - - Returns: - list_info (list): converted info. - - """ - list_info = [{} for _ in range(self.num_envs)] - list_info = self._process_episode_statistics(infos, list_info) - for k in infos: - if k.startswith("_"): - continue - for i, has_info in enumerate(infos[f"_{k}"]): - if has_info: - list_info[i][k] = infos[k][i] - return list_info - - def _process_episode_statistics(self, infos: dict, list_info: list) -> List[dict]: - """Process episode statistics. - - `RecordEpisodeStatistics` wrapper add extra - information to the info. This information are in - the form of a dict of dict. This method process these - information and add them to the info. - `RecordEpisodeStatistics` info contains the keys - "r", "l", "t" which represents "cumulative reward", - "episode length", "elapsed time since instantiation of wrapper". - - Args: - infos (dict): infos coming from `RecordEpisodeStatistics`. - list_info (list): info of the current vectorized environment. - - Returns: - list_info (list): updated info. 
- - """ - episode_statistics = infos.pop("episode", False) - if not episode_statistics: - return list_info - - episode_statistics_mask = infos.pop("_episode") - for i, has_info in enumerate(episode_statistics_mask): - if has_info: - list_info[i]["episode"] = {} - list_info[i]["episode"]["r"] = episode_statistics["r"][i] - list_info[i]["episode"]["l"] = episode_statistics["l"][i] - list_info[i]["episode"]["t"] = episode_statistics["t"][i] - - return list_info diff --git a/pyproject.toml b/pyproject.toml index 208f9b03e..5da5af912 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ all = [ testing = [ "pytest ==7.1.3", "scipy >= 1.7.3", + "dill>=0.3.7", ] [project.urls] diff --git a/tests/envs/functional/test_core.py b/tests/envs/functional/test_core.py index 001089a20..b56309535 100644 --- a/tests/envs/functional/test_core.py +++ b/tests/envs/functional/test_core.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Optional import numpy as np -from gymnasium.experimental.functional import FuncEnv +from gymnasium.functional import FuncEnv class BasicTestEnv(FuncEnv): diff --git a/tests/envs/mujoco/__init__.py b/tests/envs/mujoco/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/envs/registration/test_env_spec.py b/tests/envs/registration/test_env_spec.py index bdc78de29..387d94704 100644 --- a/tests/envs/registration/test_env_spec.py +++ b/tests/envs/registration/test_env_spec.py @@ -1,6 +1,5 @@ """Test for the `EnvSpec`, in particular, a full integration with `EnvSpec`.""" -import pickle - +import dill as pickle import pytest import gymnasium as gym @@ -136,7 +135,7 @@ def test_env_spec_pprint(): reward_threshold=475.0 max_episode_steps=500 additional_wrappers=[ - name=TimeAwareObservation, kwargs={} + name=TimeAwareObservation, kwargs={'flatten': True, 'normalize_time': False, 'dict_time_key': 'time'} ]""" ) @@ -148,7 +147,7 @@ entry_point=gymnasium.envs.classic_control.cartpole:CartPoleEnv reward_threshold=475.0 max_episode_steps=500 additional_wrappers=[ - name=TimeAwareObservation, entry_point=gymnasium.wrappers.time_aware_observation:TimeAwareObservation, kwargs={} + name=TimeAwareObservation, entry_point=gymnasium.wrappers.stateful_observation:TimeAwareObservation, kwargs={'flatten': True, 'normalize_time': False, 'dict_time_key': 'time'} ]""" ) @@ -161,11 +160,9 @@ reward_threshold=475.0 nondeterministic=False max_episode_steps=500 order_enforce=True -autoreset=False disable_env_checker=False -applied_api_compatibility=False additional_wrappers=[ - name=TimeAwareObservation, kwargs={} + name=TimeAwareObservation, kwargs={'flatten': True, 'normalize_time': False, 'dict_time_key': 'time'} ]""" ) @@ -187,8 +184,6 @@ reward_threshold=475.0 nondeterministic=False max_episode_steps=500 order_enforce=True -autoreset=False disable_env_checker=False -applied_api_compatibility=False additional_wrappers=[]""" ) diff --git a/tests/envs/registration/test_make.py b/tests/envs/registration/test_make.py index 2d54f2ed6..9cf0072e2 100644 --- a/tests/envs/registration/test_make.py +++ b/tests/envs/registration/test_make.py @@ -14,15 +14,14 @@ from gymnasium.envs.classic_control import CartPoleEnv from gymnasium.error import NameNotFound from gymnasium.utils.env_checker import data_equivalence from gymnasium.wrappers import ( - AutoResetWrapper, HumanRendering, OrderEnforcing, + PassiveEnvChecker, TimeLimit, ) -from gymnasium.wrappers.env_checker import PassiveEnvChecker from tests.envs.registration.utils_envs import ArgumentEnv from tests.envs.utils import 
all_testing_env_specs -from tests.testing_env import GenericTestEnv, old_reset_func, old_step_func +from tests.testing_env import GenericTestEnv from tests.wrappers.utils import has_wrapper @@ -93,34 +92,17 @@ def test_max_episode_steps(register_parameter_envs): assert env.spec.max_episode_steps == 100 assert has_wrapper(env, TimeLimit) - -def test_autorest(register_parameter_envs): - """Test the `autoreset` parameter in `gym.make`.""" - for make_id in [ + # Override max_episode_step to prevent applying the wrapper + for env_id in [ "CartPole-v1", gym.spec("CartPole-v1"), - "AutoresetEnv-v0", - gym.spec("AutoresetEnv-v0"), + "NoMaxEpisodeStepsEnv-v0", + gym.spec("NoMaxEpisodeStepsEnv-v0"), ]: - env_spec = gym.spec(make_id) if isinstance(make_id, str) else make_id - - # Use the spec's value - env = gym.make(make_id) + env = gym.make(env_id, max_episode_steps=-1) assert env.spec is not None - assert env.spec.autoreset == env_spec.autoreset - assert has_wrapper(env, AutoResetWrapper) is env_spec.autoreset - - # Set autoreset is True - env = gym.make(make_id, autoreset=True) - assert has_wrapper(env, AutoResetWrapper) - assert env.spec is not None - assert env.spec.autoreset is True - - # Set autoreset is False - env = gym.make(make_id, autoreset=False) - assert has_wrapper(env, AutoResetWrapper) is False - assert env.spec is not None - assert env.spec.autoreset is False + assert env.spec.max_episode_steps is None + assert has_wrapper(env, TimeLimit) is False @pytest.mark.parametrize( @@ -158,44 +140,6 @@ def test_disable_env_checker( del gym.registry["DisableEnvCheckerEnv-v0"] -def test_apply_api_compatibility(register_parameter_envs): - """Test the `apply_api_compatibility` parameter for `gym.make`.""" - # Apply the environment compatibility and check it works as intended - for make_id in ["EnabledApplyApiComp-v0", gym.spec("EnabledApplyApiComp-v0")]: - env = gym.make(make_id) - assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) - - # env has time limit of 3 enabling this test - env.reset() - assert len(env.step(env.action_space.sample())) == 5 - env.step(env.action_space.sample()) - _, _, termination, truncation, _ = env.step(env.action_space.sample()) - assert termination is False and truncation is True - - for make_id in ["DisabledApplyApiComp-v0", gym.spec("DisabledApplyApiComp-v0")]: - # Turn off the spec api compatibility - env = gym.make(make_id) - assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) is False - env.reset() - with pytest.raises( - ValueError, - match=re.escape("not enough values to unpack (expected 5, got 4)"), - ): - env.step(env.action_space.sample()) - - # Apply the environment compatibility and check it works as intended - assert env.spec is not None - assert env.spec.apply_api_compatibility is False - env = gym.make(make_id, apply_api_compatibility=True) - assert isinstance(env.unwrapped, gym.wrappers.EnvCompatibility) - - env.reset() - assert len(env.step(env.action_space.sample())) == 5 - env.step(env.action_space.sample()) - _, _, termination, truncation, _ = env.step(env.action_space.sample()) - assert termination is False and truncation is True - - def test_order_enforcing(register_parameter_envs): """Checks that gym.make wrappers the environment with the OrderEnforcing wrapper.""" assert all(spec.order_enforce is False for spec in all_testing_env_specs) @@ -324,9 +268,9 @@ def test_make_kwargs(register_kwargs_env): } assert isinstance(env.unwrapped, ArgumentEnv) - assert env.arg1 == "arg1" - assert env.arg2 == "override_arg2" - assert 
env.arg3 == "override_arg3" + assert env.unwrapped.arg1 == "arg1" + assert env.unwrapped.arg2 == "override_arg2" + assert env.unwrapped.arg3 == "override_arg3" env.close() @@ -382,12 +326,12 @@ def test_make_with_env_spec(): # make with wrapper in env-creator gym.register( "CartPole-v3", - lambda: gym.wrappers.TimeAwareObservation(CartPoleEnv()), + lambda: gym.wrappers.NormalizeReward(CartPoleEnv()), disable_env_checker=True, order_enforce=False, ) env_4 = gym.make(gym.spec("CartPole-v3")) - assert isinstance(env_4, gym.wrappers.TimeAwareObservation) + assert isinstance(env_4, gym.wrappers.NormalizeReward) assert isinstance(env_4.env, CartPoleEnv) env_4.close() @@ -396,10 +340,10 @@ def test_make_with_env_spec(): lambda: CartPoleEnv(), disable_env_checker=True, order_enforce=False, - additional_wrappers=(gym.wrappers.TimeAwareObservation.wrapper_spec(),), + additional_wrappers=(gym.wrappers.NormalizeReward.wrapper_spec(),), ) env_5 = gym.make(gym.spec("CartPole-v4")) - assert isinstance(env_5, gym.wrappers.TimeAwareObservation) + assert isinstance(env_5, gym.wrappers.NormalizeReward) assert isinstance(env_5.env, CartPoleEnv) env_5.close() @@ -516,28 +460,12 @@ def register_parameter_envs(): gym.register( "NoMaxEpisodeStepsEnv-v0", lambda: GenericTestEnv(), max_episode_steps=None ) - gym.register("AutoresetEnv-v0", lambda: GenericTestEnv(), autoreset=True) - gym.register( - "EnabledApplyApiComp-v0", - lambda: GenericTestEnv(step_func=old_step_func, reset_func=old_reset_func), - apply_api_compatibility=True, - max_episode_steps=3, - ) - gym.register( - "DisabledApplyApiComp-v0", - lambda: GenericTestEnv(step_func=old_step_func, reset_func=old_reset_func), - apply_api_compatibility=False, - max_episode_steps=3, - ) gym.register("OrderlessEnv-v0", lambda: GenericTestEnv(), order_enforce=False) yield del gym.registry["NoMaxEpisodeStepsEnv-v0"] - del gym.registry["AutoresetEnv-v0"] - del gym.registry["EnabledApplyApiComp-v0"] - del gym.registry["DisabledApplyApiComp-v0"] del gym.registry["OrderlessEnv-v0"] diff --git a/tests/envs/registration/test_make_vec.py b/tests/envs/registration/test_make_vec.py index 693253be5..866c4f755 100644 --- a/tests/envs/registration/test_make_vec.py +++ b/tests/envs/registration/test_make_vec.py @@ -1,9 +1,12 @@ """Testing of the `gym.make_vec` function.""" +import re import pytest import gymnasium as gym -from gymnasium.experimental.vector import AsyncVectorEnv, SyncVectorEnv +from gymnasium.envs.classic_control import CartPoleEnv +from gymnasium.envs.classic_control.cartpole import CartPoleVectorEnv +from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv from gymnasium.wrappers import TimeLimit, TransformObservation from tests.wrappers.utils import has_wrapper @@ -11,25 +14,65 @@ from tests.wrappers.utils import has_wrapper def test_make_vec_env_id(): """Ensure that the `gym.make_vec` creates the right environment.""" env = gym.make_vec("CartPole-v1") - assert isinstance(env, AsyncVectorEnv) + assert isinstance(env, CartPoleVectorEnv) assert env.num_envs == 1 env.close() @pytest.mark.parametrize("num_envs", [1, 3, 10]) -def test_make_vec_num_envs(num_envs): +@pytest.mark.parametrize("vectorization_mode", ["vector_entry_point", "async", "sync"]) +def test_make_vec_num_envs(num_envs, vectorization_mode): """Test that the `gym.make_vec` num_envs parameter works.""" - env = gym.make_vec("CartPole-v1", num_envs=num_envs) + env = gym.make_vec( + "CartPole-v1", num_envs=num_envs, vectorization_mode=vectorization_mode + ) assert env.num_envs == num_envs env.close() 
def test_make_vec_vectorization_mode(): """Tests the `gym.make_vec` vectorization mode works.""" + # Test the default value for spec with and without `vector_entry_point` + env_spec = gym.spec("CartPole-v1") + assert env_spec is not None and env_spec.vector_entry_point is not None + env = gym.make_vec("CartPole-v1") + assert isinstance(env, CartPoleVectorEnv) + env.close() + + env_spec = gym.spec("Pendulum-v1") + assert env_spec is not None and env_spec.vector_entry_point is None + env = gym.make_vec("Pendulum-v1") + assert isinstance(env, SyncVectorEnv) + env.close() + + # Test `vector_entry_point` + env = gym.make_vec("CartPole-v1", vectorization_mode="vector_entry_point") + assert isinstance(env, CartPoleVectorEnv) + env.close() + + with pytest.raises( + gym.error.Error, + match=re.escape( + "Cannot create vectorized environment for Pendulum-v1 because it doesn't have a vector entry point defined." + ), + ): + gym.make_vec("Pendulum-v1", vectorization_mode="vector_entry_point") + + # Test `async` env = gym.make_vec("CartPole-v1", vectorization_mode="async") assert isinstance(env, AsyncVectorEnv) env.close() + gym.register("VecOnlyEnv-v0", vector_entry_point=CartPoleVectorEnv) + with pytest.raises( + gym.error.Error, + match=re.escape( + "Cannot create vectorized environment for VecOnlyEnv-v0 because it doesn't have an entry point defined." + ), + ): + gym.make_vec("VecOnlyEnv-v0", vectorization_mode="async") + del gym.registry["VecOnlyEnv-v0"] + env = gym.make_vec("CartPole-v1", vectorization_mode="sync") assert isinstance(env, SyncVectorEnv) env.close() @@ -56,10 +99,63 @@ def test_make_vec_wrappers(): "CartPole-v1", num_envs=2, vectorization_mode="sync", - wrappers=[lambda _env: TransformObservation(_env, lambda obs: obs * 2)], + wrappers=[ + lambda _env: TransformObservation( + _env, lambda obs: obs * 2, sub_env.observation_space + ) + ], ) # As asynchronous environment are inaccessible, synchronous vector must be used assert isinstance(env, SyncVectorEnv) assert all(has_wrapper(sub_env, TransformObservation) for sub_env in env.envs) env.close() + + +@pytest.mark.parametrize( + "env_id, kwargs", + ( + ("CartPole-v1", {}), + ("CartPole-v1", {"num_envs": 3}), + ("CartPole-v1", {"vectorization_mode": "sync"}), + ("CartPole-v1", {"vectorization_mode": "vector_entry_point"}), + ( + "CartPole-v1", + {"vector_kwargs": {"copy": False}, "vectorization_mode": "sync"}, + ), + ( + "CartPole-v1", + { + "wrappers": (gym.wrappers.TimeAwareObservation,), + "vectorization_mode": "sync", + }, + ), + ("CartPole-v1", {"render_mode": "rgb_array"}), + ), +) +def test_make_vec_with_spec(env_id: str, kwargs: dict): + envs = gym.make_vec(env_id, **kwargs) + assert envs.spec is not None + recreated_envs = gym.make_vec(envs.spec) + + # Assert equivalence + assert envs.spec == recreated_envs.spec + assert envs.num_envs == recreated_envs.num_envs + + assert envs.observation_space == recreated_envs.observation_space + assert envs.single_observation_space == recreated_envs.single_observation_space + assert envs.action_space == recreated_envs.action_space + assert envs.single_action_space == recreated_envs.single_action_space + + assert type(envs) == type(recreated_envs) + + envs.close() + recreated_envs.close() + + +def test_async_with_dynamically_registered_env(): + gym.register("TestEnv-v0", CartPoleEnv) + + gym.make_vec("TestEnv-v0", vectorization_mode="async") + + del gym.registry["TestEnv-v0"] diff --git a/tests/envs/test_action_dim_check.py b/tests/envs/test_action_dim_check.py index 48ee7f844..0e6699baa 
100644 --- a/tests/envs/test_action_dim_check.py +++ b/tests/envs/test_action_dim_check.py @@ -124,7 +124,7 @@ def test_box_actions_out_of_bound(env: gym.Env): assert oob_action[i] > upper_bounds[i] oob_obs, _, _, _, _ = oob_env.step(oob_action) - assert np.alltrue(obs == oob_obs) + assert np.all(obs == oob_obs) if is_lower_bound: obs, _, _, _, _ = env.step( @@ -136,6 +136,6 @@ def test_box_actions_out_of_bound(env: gym.Env): assert oob_action[i] < lower_bounds[i] oob_obs, _, _, _, _ = oob_env.step(oob_action) - assert np.alltrue(obs == oob_obs) + assert np.all(obs == oob_obs) env.close() diff --git a/tests/envs/test_compatibility.py b/tests/envs/test_compatibility.py deleted file mode 100644 index 5446a11c7..000000000 --- a/tests/envs/test_compatibility.py +++ /dev/null @@ -1,190 +0,0 @@ -import re -from typing import Any, Dict, Optional, Tuple - -import numpy as np -import pytest -from packaging import version - -import gymnasium -from gymnasium.error import DependencyNotInstalled -from gymnasium.spaces import Discrete -from gymnasium.wrappers.compatibility import EnvCompatibility, LegacyEnv - - -try: - import gym -except ImportError: - gym = None - - -try: - import shimmy -except ImportError: - shimmy = None - - -class LegacyEnvExplicit(LegacyEnv, gymnasium.Env): - """Legacy env that explicitly implements the old API.""" - - observation_space = Discrete(1) - action_space = Discrete(1) - metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30} - - def __init__(self): - pass - - def reset(self): - return 0 - - def step(self, action): - return 0, 0, False, {} - - def render(self, mode="human"): - if mode == "human": - return - elif mode == "rgb_array": - return np.zeros((1, 1, 3), dtype=np.uint8) - - def close(self): - pass - - def seed(self, seed=None): - pass - - -class LegacyEnvImplicit(gymnasium.Env): - """Legacy env that implicitly implements the old API as a protocol.""" - - observation_space = Discrete(1) - action_space = Discrete(1) - metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30} - - def __init__(self): - pass - - def reset(self): # type: ignore - return 0 # type: ignore - - def step(self, action: Any) -> Tuple[int, float, bool, Dict]: - return 0, 0.0, False, {} - - def render(self, mode: Optional[str] = "human") -> Any: - if mode == "human": - return - elif mode == "rgb_array": - return np.zeros((1, 1, 3), dtype=np.uint8) - - def close(self): - pass - - def seed(self, seed: Optional[int] = None): - pass - - -def test_explicit(): - old_env = LegacyEnvExplicit() - assert isinstance(old_env, LegacyEnv) - env = EnvCompatibility(old_env, render_mode="rgb_array") - assert env.observation_space == Discrete(1) - assert env.action_space == Discrete(1) - assert env.reset() == (0, {}) - assert env.reset(seed=0, options={"some": "option"}) == (0, {}) - assert env.step(0) == (0, 0, False, False, {}) - assert env.render().shape == (1, 1, 3) - env.close() - - -def test_implicit(): - old_env = LegacyEnvImplicit() - assert isinstance(old_env, LegacyEnv) - env = EnvCompatibility(old_env, render_mode="rgb_array") - assert env.observation_space == Discrete(1) - assert env.action_space == Discrete(1) - assert env.reset() == (0, {}) - assert env.reset(seed=0, options={"some": "option"}) == (0, {}) - assert env.step(0) == (0, 0, False, False, {}) - assert env.render().shape == (1, 1, 3) - env.close() - - -def test_make_compatibility_in_spec(): - gymnasium.register( - id="LegacyTestEnv-v0", - entry_point=LegacyEnvExplicit, - apply_api_compatibility=True, - ) - env = 
gymnasium.make("LegacyTestEnv-v0", render_mode="rgb_array") - assert env.observation_space == Discrete(1) - assert env.action_space == Discrete(1) - assert env.reset() == (0, {}) - assert env.reset(seed=0, options={"some": "option"}) == (0, {}) - assert env.step(0) == (0, 0, False, False, {}) - img = env.render() - assert isinstance(img, np.ndarray) - assert img.shape == (1, 1, 3) # type: ignore - env.close() - del gymnasium.envs.registration.registry["LegacyTestEnv-v0"] - - -def test_make_compatibility_in_make(): - gymnasium.register(id="LegacyTestEnv-v0", entry_point=LegacyEnvExplicit) - env = gymnasium.make( - "LegacyTestEnv-v0", apply_api_compatibility=True, render_mode="rgb_array" - ) - assert env.observation_space == Discrete(1) - assert env.action_space == Discrete(1) - assert env.reset() == (0, {}) - assert env.reset(seed=0, options={"some": "option"}) == (0, {}) - assert env.step(0) == (0, 0, False, False, {}) - img = env.render() - assert isinstance(img, np.ndarray) - assert img.shape == (1, 1, 3) # type: ignore - env.close() - del gymnasium.envs.registration.registry["LegacyTestEnv-v0"] - - -@pytest.mark.parametrize( - "env", - ( - pytest.param( - "GymV21Environment-v0", - marks=pytest.mark.skipif( - gym is not None - and ( - version.parse(gym.version.VERSION) < version.parse("0.21.0") - or version.parse(gym.version.VERSION) >= version.parse("0.26.0") - ), - reason="Cannot test GymV21Environment-v0 compatibility env with gym < 0.21.0 or gym >= 0.26.0", - ), - ), - pytest.param( - "GymV26Environment-v0", - marks=pytest.mark.skipif( - gym is not None - and version.parse(gym.version.VERSION) < version.parse("0.26.0"), - reason="Cannot test GymV26Environment-v0 compatibility env with gym < 0.26.0", - ), - ), - ), -) -def test_shimmy_gym_compatibility(env): - assert gymnasium.spec(env) is not None - - if shimmy is None: - with pytest.raises( - ImportError, - match=re.escape( - "To use the gym compatibility environments, run `pip install shimmy[gym-v21]` or `pip install shimmy[gym-v26]`" - ), - ): - gymnasium.make(env) - elif gym is None: - with pytest.raises( - DependencyNotInstalled, - match=re.escape( - "No module named 'gym' (Hint: You need to install gym with `pip install gym` to use gym environments" - ), - ): - gymnasium.make(env, env_id="CartPole-v1") - else: - gymnasium.make(env, env_id="CartPole-v1") diff --git a/tests/envs/test_env_implementation.py b/tests/envs/test_env_implementation.py index 6de341046..312b4e78e 100644 --- a/tests/envs/test_env_implementation.py +++ b/tests/envs/test_env_implementation.py @@ -4,7 +4,7 @@ import numpy as np import pytest import gymnasium as gym -from gymnasium.envs.box2d import BipedalWalker +from gymnasium.envs.box2d import BipedalWalker, CarRacing from gymnasium.envs.box2d.lunar_lander import demo_heuristic_lander from gymnasium.envs.toy_text import TaxiEnv from gymnasium.envs.toy_text.frozen_lake import generate_random_map @@ -23,7 +23,7 @@ def test_carracing_domain_randomize(): CarRacing DomainRandomize should have different colours at every reset. However, it should have same colours when `options={"randomize": False}` is given to reset. 
""" - env = gym.make("CarRacing-v2", domain_randomize=True) + env: CarRacing = gym.make("CarRacing-v2", domain_randomize=True).unwrapped road_color = env.road_color bg_color = env.bg_color @@ -173,7 +173,7 @@ def test_customizable_resets(env_name: str, low_high: Optional[list]): else: low, high = low_high env.reset(options={"low": low, "high": high}) - assert np.all((env.state >= low) & (env.state <= high)) + assert np.all((env.unwrapped.state >= low) & (env.unwrapped.state <= high)) # Make sure we can take a step. env.step(env.action_space.sample()) diff --git a/tests/experimental/__init__.py b/tests/experimental/__init__.py deleted file mode 100644 index 203b13d74..000000000 --- a/tests/experimental/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Testing suite for ``gymnasium.experimental``.""" diff --git a/tests/experimental/vector/__init__.py b/tests/experimental/vector/__init__.py deleted file mode 100644 index 567e426a8..000000000 --- a/tests/experimental/vector/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Testing for `gymnasium.experimental.vector`.""" diff --git a/tests/experimental/vector/test_async_vector_env.py b/tests/experimental/vector/test_async_vector_env.py deleted file mode 100644 index 68c3585c5..000000000 --- a/tests/experimental/vector/test_async_vector_env.py +++ /dev/null @@ -1,329 +0,0 @@ -"""Test the `SyncVectorEnv` implementation.""" - -import re -from multiprocessing import TimeoutError - -import numpy as np -import pytest - -from gymnasium.error import ( - AlreadyPendingCallError, - ClosedEnvironmentError, - NoAsyncCallError, -) -from gymnasium.experimental.vector import AsyncVectorEnv -from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple -from tests.experimental.vector.testing_utils import ( - CustomSpace, - make_custom_space_env, - make_env, - make_slow_env, -) - - -@pytest.mark.parametrize("shared_memory", [True, False]) -def test_create_async_vector_env(shared_memory): - """Test creating an async vector environment with or without shared memory.""" - env_fns = [make_env("CartPole-v1", i) for i in range(8)] - - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - assert env.num_envs == 8 - env.close() - - -@pytest.mark.parametrize("shared_memory", [True, False]) -def test_reset_async_vector_env(shared_memory): - """Test the reset of an sync vector environment with or without shared memory.""" - env_fns = [make_env("CartPole-v1", i) for i in range(8)] - - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - observations, infos = env.reset() - - env.close() - - assert isinstance(env.observation_space, Box) - assert isinstance(observations, np.ndarray) - assert observations.dtype == env.observation_space.dtype - assert observations.shape == (8,) + env.single_observation_space.shape - assert observations.shape == env.observation_space.shape - - try: - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - observations, infos = env.reset() - finally: - env.close() - - assert isinstance(env.observation_space, Box) - assert isinstance(observations, np.ndarray) - assert observations.dtype == env.observation_space.dtype - assert observations.shape == (8,) + env.single_observation_space.shape - assert observations.shape == env.observation_space.shape - assert isinstance(infos, dict) - assert all([isinstance(info, dict) for info in infos]) - - -@pytest.mark.parametrize("shared_memory", [True, False]) -@pytest.mark.parametrize("use_single_action_space", [True, False]) -def test_step_async_vector_env(shared_memory, use_single_action_space): - 
"""Test the step async vector environment with and without shared memory.""" - env_fns = [make_env("CartPole-v1", i) for i in range(8)] - - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - observations = env.reset() - - assert isinstance(env.single_action_space, Discrete) - assert isinstance(env.action_space, MultiDiscrete) - - if use_single_action_space: - actions = [env.single_action_space.sample() for _ in range(8)] - else: - actions = env.action_space.sample() - observations, rewards, terminations, truncations, _ = env.step(actions) - - env.close() - - assert isinstance(env.observation_space, Box) - assert isinstance(observations, np.ndarray) - assert observations.dtype == env.observation_space.dtype - assert observations.shape == (8,) + env.single_observation_space.shape - assert observations.shape == env.observation_space.shape - - assert isinstance(rewards, np.ndarray) - assert isinstance(rewards[0], (float, np.floating)) - assert rewards.ndim == 1 - assert rewards.size == 8 - - assert isinstance(terminations, np.ndarray) - assert terminations.dtype == np.bool_ - assert terminations.ndim == 1 - assert terminations.size == 8 - - assert isinstance(truncations, np.ndarray) - assert truncations.dtype == np.bool_ - assert truncations.ndim == 1 - assert truncations.size == 8 - - -@pytest.mark.parametrize("shared_memory", [True, False]) -def test_call_async_vector_env(shared_memory): - """Test call with async vector environment.""" - env_fns = [ - make_env("CartPole-v1", i, render_mode="rgb_array_list") for i in range(4) - ] - - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - _ = env.reset() - images = env.call("render") - gravity = env.call("gravity") - - env.close() - - assert isinstance(images, tuple) - assert len(images) == 4 - for i in range(4): - assert len(images[i]) == 1 - assert isinstance(images[i][0], np.ndarray) - - assert isinstance(gravity, tuple) - assert len(gravity) == 4 - for i in range(4): - assert isinstance(gravity[i], float) - assert gravity[i] == 9.8 - - -@pytest.mark.parametrize("shared_memory", [True, False]) -def test_set_attr_async_vector_env(shared_memory): - """Test `set_attr_` for async vector environment with or without shared memory.""" - env_fns = [make_env("CartPole-v1", i) for i in range(4)] - - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62]) - gravity = env.get_attr("gravity") - assert gravity == (9.81, 3.72, 8.87, 1.62) - - env.close() - - -@pytest.mark.parametrize("shared_memory", [True, False]) -def test_copy_async_vector_env(shared_memory): - """Test observations are a copy of the true observation with and without shared memory.""" - env_fns = [make_env("CartPole-v1", i) for i in range(8)] - - # TODO, these tests do nothing, understand the purpose of the tests and fix them - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=True) - observations, infos = env.reset() - observations[0] = 0 - - env.close() - - -@pytest.mark.parametrize("shared_memory", [True, False]) -def test_no_copy_async_vector_env(shared_memory): - """Test observation are not a copy of the true observation with and without shared memory.""" - env_fns = [make_env("CartPole-v1", i) for i in range(8)] - - # TODO, these tests do nothing, understand the purpose of the tests and fix them - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=False) - observations, infos = env.reset() - observations[0] = 0 - - env.close() - - -@pytest.mark.parametrize("shared_memory", [True, 
False]) -def test_reset_timeout_async_vector_env(shared_memory): - """Test timeout error on reset with and without shared memory.""" - env_fns = [make_slow_env(0.3, i) for i in range(4)] - - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - with pytest.raises(TimeoutError): - env.reset_async() - env.reset_wait(timeout=0.1) - - env.close(terminate=True) - - -@pytest.mark.parametrize("shared_memory", [True, False]) -def test_step_timeout_async_vector_env(shared_memory): - """Test timeout error on step with and without shared memory.""" - env_fns = [make_slow_env(0.0, i) for i in range(4)] - - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - with pytest.raises(TimeoutError): - env.reset() - env.step_async(np.array([0.1, 0.1, 0.3, 0.1])) - observations, rewards, terminations, truncations, _ = env.step_wait(timeout=0.1) - env.close(terminate=True) - - -@pytest.mark.parametrize("shared_memory", [True, False]) -def test_reset_out_of_order_async_vector_env(shared_memory): - """Test reset being called out of order with and without shared memory.""" - env_fns = [make_env("CartPole-v1", i) for i in range(4)] - - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - with pytest.raises( - NoAsyncCallError, - match=re.escape( - "Calling `reset_wait` without any prior call to `reset_async`." - ), - ): - env.reset_wait() - - env.close(terminate=True) - - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - with pytest.raises( - AlreadyPendingCallError, - match=re.escape( - "Calling `reset_async` while waiting for a pending call to `step` to complete" - ), - ): - actions = env.action_space.sample() - env.reset() - env.step_async(actions) - env.reset_async() - - with pytest.warns( - UserWarning, - match=re.escape( - "Calling `close` while waiting for a pending call to `step` to complete." - ), - ): - env.close(terminate=True) - - -@pytest.mark.parametrize("shared_memory", [True, False]) -def test_step_out_of_order_async_vector_env(shared_memory): - """Test step out of order with and without shared memory.""" - env_fns = [make_env("CartPole-v1", i) for i in range(4)] - - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - with pytest.raises( - NoAsyncCallError, - match=re.escape("Calling `step_wait` without any prior call to `step_async`."), - ): - env.action_space.sample() - env.reset() - env.step_wait() - - env.close(terminate=True) - - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - with pytest.raises( - AlreadyPendingCallError, - match=re.escape( - "Calling `step_async` while waiting for a pending call to `reset` to complete" - ), - ): - actions = env.action_space.sample() - env.reset_async() - env.step_async(actions) - - with pytest.warns( - UserWarning, - match=re.escape( - "Calling `close` while waiting for a pending call to `reset` to complete." 
- ), - ): - env.close(terminate=True) - - -@pytest.mark.parametrize("shared_memory", [True, False]) -def test_already_closed_async_vector_env(shared_memory): - """Test the error if a function is called if environment is already closed.""" - env_fns = [make_env("CartPole-v1", i) for i in range(4)] - with pytest.raises(ClosedEnvironmentError): - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - env.close() - env.reset() - - -@pytest.mark.parametrize("shared_memory", [True, False]) -def test_check_spaces_async_vector_env(shared_memory): - """Test check spaces for async vector environment with and without shared memory.""" - # CartPole-v1 - observation_space: Box(4,), action_space: Discrete(2) - env_fns = [make_env("CartPole-v1", i) for i in range(8)] - # FrozenLake-v1 - Discrete(16), action_space: Discrete(4) - env_fns[1] = make_env("FrozenLake-v1", 1) - with pytest.raises(RuntimeError): - env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - env.close(terminate=True) - - -def test_custom_space_async_vector_env(): - """Test custom spaces with async vector environment.""" - env_fns = [make_custom_space_env(i) for i in range(4)] - - env = AsyncVectorEnv(env_fns, shared_memory=False) - reset_observations, reset_infos = env.reset() - - assert isinstance(env.single_action_space, CustomSpace) - assert isinstance(env.action_space, Tuple) - - actions = ("action-2", "action-3", "action-5", "action-7") - step_observations, rewards, terminations, truncations, _ = env.step(actions) - - env.close() - - assert isinstance(env.single_observation_space, CustomSpace) - assert isinstance(env.observation_space, Tuple) - - assert isinstance(reset_observations, tuple) - assert reset_observations == ("reset", "reset", "reset", "reset") - - assert isinstance(step_observations, tuple) - assert step_observations == ( - "step(action-2)", - "step(action-3)", - "step(action-5)", - "step(action-7)", - ) - - -def test_custom_space_async_vector_env_shared_memory(): - """Test custom space with shared memory.""" - env_fns = [make_custom_space_env(i) for i in range(4)] - with pytest.raises(ValueError): - env = AsyncVectorEnv(env_fns, shared_memory=True) - env.close(terminate=True) diff --git a/tests/experimental/vector/test_sync_vector_env.py b/tests/experimental/vector/test_sync_vector_env.py deleted file mode 100644 index 50781966a..000000000 --- a/tests/experimental/vector/test_sync_vector_env.py +++ /dev/null @@ -1,187 +0,0 @@ -"""Test the `SyncVectorEnv` implementation.""" - -import numpy as np -import pytest - -from gymnasium.envs.registration import EnvSpec -from gymnasium.experimental.vector import SyncVectorEnv -from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple -from tests.envs.utils import all_testing_env_specs -from tests.vector.utils import ( - CustomSpace, - assert_rng_equal, - make_custom_space_env, - make_env, -) - - -def test_create_sync_vector_env(): - """Tests creating the sync vector environment.""" - env_fns = [make_env("FrozenLake-v1", i) for i in range(8)] - env = SyncVectorEnv(env_fns) - env.close() - - assert env.num_envs == 8 - - -def test_reset_sync_vector_env(): - """Tests sync vector `reset` function.""" - env_fns = [make_env("CartPole-v1", i) for i in range(8)] - env = SyncVectorEnv(env_fns) - observations, infos = env.reset() - env.close() - - assert isinstance(env.observation_space, Box) - assert isinstance(observations, np.ndarray) - assert observations.dtype == env.observation_space.dtype - assert observations.shape == (8,) + env.single_observation_space.shape - 
assert observations.shape == env.observation_space.shape - - del observations - - -@pytest.mark.parametrize("use_single_action_space", [True, False]) -def test_step_sync_vector_env(use_single_action_space): - """Test sync vector `steps` function.""" - env_fns = [make_env("FrozenLake-v1", i) for i in range(8)] - - env = SyncVectorEnv(env_fns) - observations = env.reset() - - assert isinstance(env.single_action_space, Discrete) - assert isinstance(env.action_space, MultiDiscrete) - - if use_single_action_space: - actions = [env.single_action_space.sample() for _ in range(8)] - else: - actions = env.action_space.sample() - observations, rewards, terminateds, truncateds, _ = env.step(actions) - - env.close() - - assert isinstance(env.observation_space, MultiDiscrete) - assert isinstance(observations, np.ndarray) - assert observations.dtype == env.observation_space.dtype - assert observations.shape == (8,) + env.single_observation_space.shape - assert observations.shape == env.observation_space.shape - - assert isinstance(rewards, np.ndarray) - assert isinstance(rewards[0], (float, np.floating)) - assert rewards.ndim == 1 - assert rewards.size == 8 - - assert isinstance(terminateds, np.ndarray) - assert terminateds.dtype == np.bool_ - assert terminateds.ndim == 1 - assert terminateds.size == 8 - - assert isinstance(truncateds, np.ndarray) - assert truncateds.dtype == np.bool_ - assert truncateds.ndim == 1 - assert truncateds.size == 8 - - -def test_call_sync_vector_env(): - """Test sync vector `call` on sub-environments.""" - env_fns = [ - make_env("CartPole-v1", i, render_mode="rgb_array_list") for i in range(4) - ] - - env = SyncVectorEnv(env_fns) - _ = env.reset() - images = env.call("render") - gravity = env.call("gravity") - - env.close() - - assert isinstance(images, tuple) - assert len(images) == 4 - for i in range(4): - assert len(images[i]) == 1 - assert isinstance(images[i][0], np.ndarray) - - assert isinstance(gravity, tuple) - assert len(gravity) == 4 - for i in range(4): - assert isinstance(gravity[i], float) - assert gravity[i] == 9.8 - - -def test_set_attr_sync_vector_env(): - """Test sync vector `set_attr` function.""" - env_fns = [make_env("CartPole-v1", i) for i in range(4)] - - env = SyncVectorEnv(env_fns) - env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62]) - gravity = env.get_attr("gravity") - assert gravity == (9.81, 3.72, 8.87, 1.62) - - env.close() - - -def test_check_spaces_sync_vector_env(): - """Tests the sync vector `check_spaces` function.""" - # CartPole-v1 - observation_space: Box(4,), action_space: Discrete(2) - env_fns = [make_env("CartPole-v1", i) for i in range(8)] - # FrozenLake-v1 - Discrete(16), action_space: Discrete(4) - env_fns[1] = make_env("FrozenLake-v1", 1) - with pytest.raises(RuntimeError): - env = SyncVectorEnv(env_fns) - env.close() - - -def test_custom_space_sync_vector_env(): - """Test the use of custom spaces with sync vector environment.""" - env_fns = [make_custom_space_env(i) for i in range(4)] - - env = SyncVectorEnv(env_fns) - reset_observations, infos = env.reset() - - assert isinstance(env.single_action_space, CustomSpace) - assert isinstance(env.action_space, Tuple) - - actions = ("action-2", "action-3", "action-5", "action-7") - step_observations, rewards, terminateds, truncateds, _ = env.step(actions) - - env.close() - - assert isinstance(env.single_observation_space, CustomSpace) - assert isinstance(env.observation_space, Tuple) - - assert isinstance(reset_observations, tuple) - assert reset_observations == ("reset", "reset", 
"reset", "reset") - - assert isinstance(step_observations, tuple) - assert step_observations == ( - "step(action-2)", - "step(action-3)", - "step(action-5)", - "step(action-7)", - ) - - -def test_sync_vector_env_seed(): - """Test seeding for sync vector environments.""" - env = make_env("BipedalWalker-v3", seed=123)() - sync_vector_env = SyncVectorEnv([make_env("BipedalWalker-v3", seed=123)]) - - assert_rng_equal(env.action_space.np_random, sync_vector_env.action_space.np_random) - for _ in range(100): - env_action = env.action_space.sample() - vector_action = sync_vector_env.action_space.sample() - assert np.all(env_action == vector_action) - - -@pytest.mark.parametrize( - "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs] -) -def test_sync_vector_determinism(spec: EnvSpec, seed: int = 123, n: int = 3): - """Check that for all environments, the sync vector envs produce the same action samples using the same seeds.""" - env_1 = SyncVectorEnv([make_env(spec.id, seed=seed) for _ in range(n)]) - env_2 = SyncVectorEnv([make_env(spec.id, seed=seed) for _ in range(n)]) - assert_rng_equal(env_1.action_space.np_random, env_2.action_space.np_random) - - for _ in range(100): - env_1_samples = env_1.action_space.sample() - env_2_samples = env_2.action_space.sample() - assert np.all(env_1_samples == env_2_samples) diff --git a/tests/experimental/vector/test_vector_env.py b/tests/experimental/vector/test_vector_env.py deleted file mode 100644 index 685aada83..000000000 --- a/tests/experimental/vector/test_vector_env.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Test vector environment implementations.""" - -from functools import partial - -import numpy as np -import pytest - -from gymnasium.experimental.vector import AsyncVectorEnv, SyncVectorEnv -from gymnasium.spaces import Discrete -from tests.testing_env import GenericTestEnv -from tests.vector.utils import make_env - - -@pytest.mark.parametrize("shared_memory", [True, False]) -def test_vector_env_equal(shared_memory): - """Test that vector environment are equal for both async and sync variants.""" - env_fns = [make_env("CartPole-v1", i) for i in range(4)] - num_steps = 100 - - async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - sync_env = SyncVectorEnv(env_fns) - - assert async_env.num_envs == sync_env.num_envs - assert async_env.observation_space == sync_env.observation_space - assert async_env.single_observation_space == sync_env.single_observation_space - assert async_env.action_space == sync_env.action_space - assert async_env.single_action_space == sync_env.single_action_space - - async_observations, async_infos = async_env.reset(seed=0) - sync_observations, sync_infos = sync_env.reset(seed=0) - assert np.all(async_observations == sync_observations) - - for _ in range(num_steps): - actions = async_env.action_space.sample() - assert actions in sync_env.action_space - - ( - async_observations, - async_rewards, - async_terminations, - async_truncations, - async_infos, - ) = async_env.step(actions) - ( - sync_observations, - sync_rewards, - sync_terminations, - sync_truncations, - sync_infos, - ) = sync_env.step(actions) - - if any(sync_terminations) or any(sync_truncations): - assert "final_observation" in async_infos - assert "_final_observation" in async_infos - assert "final_observation" in sync_infos - assert "_final_observation" in sync_infos - - assert np.all(async_observations == sync_observations) - assert np.all(async_rewards == sync_rewards) - assert np.all(async_terminations == sync_terminations) - 
assert np.all(async_truncations == sync_truncations) - - async_env.close() - sync_env.close() - - -@pytest.mark.parametrize( - "vectoriser", - ( - SyncVectorEnv, - partial(AsyncVectorEnv, shared_memory=True), - partial(AsyncVectorEnv, shared_memory=False), - ), - ids=["Sync", "Async with shared memory", "Async without shared memory"], -) -def test_final_obs_info(vectoriser): - """Tests that the vector environments correctly return the final observation and info.""" - - def reset_fn(self, seed=None, options=None): - return 0, {"reset": True} - - def thunk(): - return GenericTestEnv( - action_space=Discrete(4), - observation_space=Discrete(4), - reset_func=reset_fn, - step_func=lambda self, action: ( - action if action < 3 else 0, - 0, - action >= 3, - False, - {"action": action}, - ), - ) - - env = vectoriser([thunk]) - obs, info = env.reset() - assert obs == np.array([0]) and info == { - "reset": np.array([True]), - "_reset": np.array([True]), - } - - obs, _, termination, _, info = env.step([1]) - assert ( - obs == np.array([1]) - and termination == np.array([False]) - and info == {"action": np.array([1]), "_action": np.array([True])} - ) - - obs, _, termination, _, info = env.step([2]) - assert ( - obs == np.array([2]) - and termination == np.array([False]) - and info == {"action": np.array([2]), "_action": np.array([True])} - ) - - obs, _, termination, _, info = env.step([3]) - assert ( - obs == np.array([0]) - and termination == np.array([True]) - and info["reset"] == np.array([True]) - ) - assert "final_observation" in info and "final_info" in info - assert info["final_observation"] == np.array([0]) and info["final_info"] == { - "action": 3 - } diff --git a/tests/experimental/vector/test_vector_env_info.py b/tests/experimental/vector/test_vector_env_info.py deleted file mode 100644 index 9050146c9..000000000 --- a/tests/experimental/vector/test_vector_env_info.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Test the vector environment information.""" -import numpy as np -import pytest - -import gymnasium as gym -from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv -from tests.vector.utils import make_env - - -ENV_ID = "CartPole-v1" -NUM_ENVS = 3 -ENV_STEPS = 50 -SEED = 42 - - -@pytest.mark.parametrize("vectorization_mode", ["async", "sync"]) -def test_vector_env_info(vectorization_mode: str): - """Test vector environment info for different vectorization modes.""" - env = gym.make_vec( - ENV_ID, - num_envs=NUM_ENVS, - vectorization_mode=vectorization_mode, - ) - env.reset(seed=SEED) - for _ in range(ENV_STEPS): - env.action_space.seed(SEED) - action = env.action_space.sample() - _, _, terminateds, truncateds, infos = env.step(action) - if any(terminateds) or any(truncateds): - assert len(infos["final_observation"]) == NUM_ENVS - assert len(infos["_final_observation"]) == NUM_ENVS - - assert isinstance(infos["final_observation"], np.ndarray) - assert isinstance(infos["_final_observation"], np.ndarray) - - for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)): - if terminated or truncated: - assert infos["_final_observation"][i] - else: - assert not infos["_final_observation"][i] - assert infos["final_observation"][i] is None - - -@pytest.mark.parametrize("concurrent_ends", [1, 2, 3]) -def test_vector_env_info_concurrent_termination(concurrent_ends): - """Test the vector environment information works with concurrent termination.""" - # envs that need to terminate together will have the same action - actions = [0] * concurrent_ends + [1] * (NUM_ENVS - 
concurrent_ends) - envs = [make_env(ENV_ID, SEED) for _ in range(NUM_ENVS)] - envs = SyncVectorEnv(envs) - - for _ in range(ENV_STEPS): - _, _, terminateds, truncateds, infos = envs.step(actions) - if any(terminateds) or any(truncateds): - for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)): - if i < concurrent_ends: - assert terminated or truncated - assert infos["_final_observation"][i] - else: - assert not infos["_final_observation"][i] - assert infos["final_observation"][i] is None - return diff --git a/tests/experimental/vector/utils/__init__.py b/tests/experimental/vector/utils/__init__.py deleted file mode 100644 index bc83cd7d0..000000000 --- a/tests/experimental/vector/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Module for testing `gymnasium.experimental.vector.utils` functions.""" diff --git a/tests/experimental/wrappers/__init__.py b/tests/experimental/wrappers/__init__.py deleted file mode 100644 index a100571ed..000000000 --- a/tests/experimental/wrappers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Experimental wrapper module.""" diff --git a/tests/experimental/wrappers/test_atari_preprocessing.py b/tests/experimental/wrappers/test_atari_preprocessing.py deleted file mode 100644 index fa0aac8d2..000000000 --- a/tests/experimental/wrappers/test_atari_preprocessing.py +++ /dev/null @@ -1 +0,0 @@ -"""Test suite for AtariPreprocessingV0.""" diff --git a/tests/experimental/wrappers/test_autoreset.py b/tests/experimental/wrappers/test_autoreset.py deleted file mode 100644 index 948815732..000000000 --- a/tests/experimental/wrappers/test_autoreset.py +++ /dev/null @@ -1 +0,0 @@ -"""Test suite for AutoresetV0.""" diff --git a/tests/experimental/wrappers/test_clip_action.py b/tests/experimental/wrappers/test_clip_action.py deleted file mode 100644 index 2f80a1183..000000000 --- a/tests/experimental/wrappers/test_clip_action.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Test suite for ClipActionV0.""" - -import numpy as np - -from gymnasium.experimental.wrappers import ClipActionV0 -from gymnasium.spaces import Box -from tests.experimental.wrappers.utils import record_action_step -from tests.testing_env import GenericTestEnv - - -def test_clip_action_wrapper(): - """Test that the action is correctly clipped to the base environment action space.""" - env = GenericTestEnv( - action_space=Box(np.array([0, 0, 3]), np.array([1, 2, 4])), - step_func=record_action_step, - ) - wrapped_env = ClipActionV0(env) - - sampled_action = np.array([-1, 5, 3.5], dtype=np.float32) - assert sampled_action not in env.action_space - assert sampled_action in wrapped_env.action_space - - _, _, _, _, info = wrapped_env.step(sampled_action) - assert np.all(info["action"] in env.action_space) - assert np.all(info["action"] == np.array([0, 2, 3.5])) diff --git a/tests/experimental/wrappers/test_clip_reward.py b/tests/experimental/wrappers/test_clip_reward.py deleted file mode 100644 index b88290daf..000000000 --- a/tests/experimental/wrappers/test_clip_reward.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Test suite for ClipRewardV0.""" -import numpy as np -import pytest - -import gymnasium as gym -from gymnasium.error import InvalidBound -from gymnasium.experimental.wrappers import ClipRewardV0 -from tests.envs.test_envs import SEED -from tests.experimental.wrappers.test_lambda_rewards import ( - DISCRETE_ACTION, - ENV_ID, - NUM_ENVS, -) - - -@pytest.mark.parametrize( - ("lower_bound", "upper_bound", "expected_reward"), - [(None, 0.5, 0.5), (0, None, 1), (0, 0.5, 0.5)], -) -def 
test_clip_reward(lower_bound, upper_bound, expected_reward): - """Test reward clipping. - - Test if reward is correctly clipped accordingly to the input args. - """ - env = gym.make(ENV_ID) - env = ClipRewardV0(env, lower_bound, upper_bound) - env.reset(seed=SEED) - _, rew, _, _, _ = env.step(DISCRETE_ACTION) - - assert rew == expected_reward - - -@pytest.mark.parametrize( - ("lower_bound", "upper_bound", "expected_reward"), - [(None, 0.5, 0.5), (0, None, 1), (0, 0.5, 0.5)], -) -def test_clip_reward_within_vector(lower_bound, upper_bound, expected_reward): - """Test reward clipping in vectorized environment. - - Test if reward is correctly clipped accordingly to the input args in a vectorized environment. - """ - actions = [DISCRETE_ACTION for _ in range(NUM_ENVS)] - - env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS) - env = ClipRewardV0(env, lower_bound, upper_bound) - env.reset(seed=SEED) - - _, rew, _, _, _ = env.step(actions) - - assert np.alltrue(rew == expected_reward) - - -@pytest.mark.parametrize( - ("lower_bound", "upper_bound"), - [(None, None), (1, -1), (np.array([1, 1]), np.array([0, 0]))], -) -def test_clip_reward_incorrect_params(lower_bound, upper_bound): - """Test reward clipping with incorrect params. - - Test whether passing wrong params to clip_rewards correctly raise an exception. - clip_rewards should raise an exception if, both low and upper bound of reward are `None` - or if upper bound is lower than lower bound. - """ - env = gym.make(ENV_ID) - - with pytest.raises(InvalidBound): - ClipRewardV0(env, lower_bound, upper_bound) diff --git a/tests/experimental/wrappers/test_filter_observation.py b/tests/experimental/wrappers/test_filter_observation.py deleted file mode 100644 index 5232cb840..000000000 --- a/tests/experimental/wrappers/test_filter_observation.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Test suite for FilterObservationV0.""" -from gymnasium.experimental.wrappers import FilterObservationV0 -from gymnasium.spaces import Box, Dict, Tuple -from tests.experimental.wrappers.utils import ( - check_obs, - record_random_obs_reset, - record_random_obs_step, -) -from tests.testing_env import GenericTestEnv - - -def test_filter_observation_wrapper(): - """Tests ``FilterObservation`` that the right keys are filtered.""" - dict_env = GenericTestEnv( - observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3), arm_3=Box(-1, 1)), - reset_func=record_random_obs_reset, - step_func=record_random_obs_step, - ) - - wrapped_env = FilterObservationV0(dict_env, ("arm_1", "arm_3")) - obs, info = wrapped_env.reset() - assert list(obs.keys()) == ["arm_1", "arm_3"] - assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"] - check_obs(dict_env, wrapped_env, obs, info["obs"]) - - obs, _, _, _, info = wrapped_env.step(None) - assert list(obs.keys()) == ["arm_1", "arm_3"] - assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"] - check_obs(dict_env, wrapped_env, obs, info["obs"]) - - # Test tuple environments - tuple_env = GenericTestEnv( - observation_space=Tuple((Box(0, 1), Box(2, 3), Box(-1, 1))), - reset_func=record_random_obs_reset, - step_func=record_random_obs_step, - ) - wrapped_env = FilterObservationV0(tuple_env, (2,)) - - obs, info = wrapped_env.reset() - assert len(obs) == 1 and len(info["obs"]) == 3 - check_obs(tuple_env, wrapped_env, obs, info["obs"]) - - obs, _, _, _, info = wrapped_env.step(None) - assert len(obs) == 1 and len(info["obs"]) == 3 - check_obs(tuple_env, wrapped_env, obs, info["obs"]) diff --git 
a/tests/experimental/wrappers/test_flatten_observation.py b/tests/experimental/wrappers/test_flatten_observation.py deleted file mode 100644 index b7ecfb4bf..000000000 --- a/tests/experimental/wrappers/test_flatten_observation.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Test suite for FlattenObservationV0.""" -from gymnasium.experimental.wrappers import FlattenObservationV0 -from gymnasium.spaces import Box, Dict -from tests.experimental.wrappers.utils import ( - check_obs, - record_random_obs_reset, - record_random_obs_step, -) -from tests.testing_env import GenericTestEnv - - -def test_flatten_observation_wrapper(): - """Tests the ``FlattenObservation`` wrapper that the observation are flattened correctly.""" - env = GenericTestEnv( - observation_space=Dict(arm=Box(0, 1), head=Box(2, 3)), - reset_func=record_random_obs_reset, - step_func=record_random_obs_step, - ) - wrapped_env = FlattenObservationV0(env) - - obs, info = wrapped_env.reset() - check_obs(env, wrapped_env, obs, info["obs"]) - - obs, _, _, _, info = wrapped_env.step(None) - check_obs(env, wrapped_env, obs, info["obs"]) diff --git a/tests/experimental/wrappers/test_grayscale_observation.py b/tests/experimental/wrappers/test_grayscale_observation.py deleted file mode 100644 index 74a7a6a17..000000000 --- a/tests/experimental/wrappers/test_grayscale_observation.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Test suite for GrayscaleObservationV0.""" -import numpy as np - -from gymnasium.experimental.wrappers import GrayscaleObservationV0 -from gymnasium.spaces import Box -from tests.experimental.wrappers.utils import ( - check_obs, - record_random_obs_reset, - record_random_obs_step, -) -from tests.testing_env import GenericTestEnv - - -def test_grayscale_observation_wrapper(): - """Tests the ``GrayscaleObservation`` that the observation is grayscale.""" - env = GenericTestEnv( - observation_space=Box(0, 255, shape=(25, 25, 3), dtype=np.uint8), - reset_func=record_random_obs_reset, - step_func=record_random_obs_step, - ) - wrapped_env = GrayscaleObservationV0(env) - - obs, info = wrapped_env.reset() - check_obs(env, wrapped_env, obs, info["obs"]) - assert obs.shape == (25, 25) - - obs, _, _, _, info = wrapped_env.step(None) - check_obs(env, wrapped_env, obs, info["obs"]) - - # Keep_dim - wrapped_env = GrayscaleObservationV0(env, keep_dim=True) - - obs, info = wrapped_env.reset() - check_obs(env, wrapped_env, obs, info["obs"]) - assert obs.shape == (25, 25, 1) - - obs, _, _, _, info = wrapped_env.step(None) - check_obs(env, wrapped_env, obs, info["obs"]) diff --git a/tests/experimental/wrappers/test_human_rendering.py b/tests/experimental/wrappers/test_human_rendering.py deleted file mode 100644 index 2cce4d70f..000000000 --- a/tests/experimental/wrappers/test_human_rendering.py +++ /dev/null @@ -1 +0,0 @@ -"""Test suite for HumanRenderingV0.""" diff --git a/tests/experimental/wrappers/test_import_wrappers.py b/tests/experimental/wrappers/test_import_wrappers.py deleted file mode 100644 index ceb644e4f..000000000 --- a/tests/experimental/wrappers/test_import_wrappers.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Test suite for import wrappers.""" - -import re - -import pytest - -import gymnasium -import gymnasium.experimental.wrappers as wrappers -from gymnasium.experimental.wrappers import __all__ - - -def test_import_wrappers(): - """Test that all wrappers can be imported.""" - # Test that a deprecated wrapper raises a DeprecatedWrapper - with pytest.raises( - wrappers.DeprecatedWrapper, - match=re.escape("'NormalizeRewardV0' is now deprecated"), - 
): - getattr(wrappers, "NormalizeRewardV0") - - # Test that an invalid version raises an AttributeError - with pytest.raises( - AttributeError, - match=re.escape( - "module 'gymnasium.experimental.wrappers' has no attribute 'ClipRewardVT', did you mean" - ), - ): - getattr(wrappers, "ClipRewardVT") - - with pytest.raises( - AttributeError, - match=re.escape( - "module 'gymnasium.experimental.wrappers' has no attribute 'ClipRewardV99', did you mean" - ), - ): - getattr(wrappers, "ClipRewardV99") - - # Test that an invalid wrapper raises an AttributeError - with pytest.raises( - AttributeError, - match=re.escape( - "module 'gymnasium.experimental.wrappers' has no attribute 'NonexistentWrapper'" - ), - ): - getattr(wrappers, "NonexistentWrapper") - - -@pytest.mark.parametrize("wrapper_name", __all__) -def test_all_wrappers_shortened(wrapper_name): - """Check that each element of the `__all__` wrappers can be loaded, provided dependencies are installed.""" - try: - assert getattr(gymnasium.experimental.wrappers, wrapper_name) is not None - except gymnasium.error.DependencyNotInstalled as e: - pytest.skip(str(e)) - - -def test_wrapper_vector(): - assert gymnasium.experimental.wrappers.vector is not None diff --git a/tests/experimental/wrappers/test_lambda_rewards.py b/tests/experimental/wrappers/test_lambda_rewards.py deleted file mode 100644 index 697c30c76..000000000 --- a/tests/experimental/wrappers/test_lambda_rewards.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Test lambda reward wrapper.""" - -import numpy as np -import pytest - -import gymnasium as gym -from gymnasium.experimental.wrappers import LambdaRewardV0 -from tests.experimental.wrappers.utils import DISCRETE_ACTION, ENV_ID, NUM_ENVS, SEED - - -@pytest.mark.parametrize( - ("reward_fn", "expected_reward"), - [(lambda r: 2 * r + 1, 3)], -) -def test_lambda_reward(reward_fn, expected_reward): - """Test lambda reward. - - Tests if function is correctly applied - to reward. - """ - env = gym.make(ENV_ID) - env = LambdaRewardV0(env, reward_fn) - env.reset(seed=SEED) - - _, rew, _, _, _ = env.step(DISCRETE_ACTION) - - assert rew == expected_reward - - -@pytest.mark.parametrize( - ( - "reward_fn", - "expected_reward", - ), - [(lambda r: 2 * r + 1, 3)], -) -def test_lambda_reward_within_vector(reward_fn, expected_reward): - """Test lambda reward in vectorized environment. - - Tests if function is correctly applied - to reward in a vectorized environment. 
- """ - actions = [DISCRETE_ACTION for _ in range(NUM_ENVS)] - env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS) - env = LambdaRewardV0(env, reward_fn) - env.reset(seed=SEED) - - _, rew, _, _, _ = env.step(actions) - - assert np.alltrue(rew == expected_reward) diff --git a/tests/experimental/wrappers/test_normalize_reward.py b/tests/experimental/wrappers/test_normalize_reward.py deleted file mode 100644 index 6621414eb..000000000 --- a/tests/experimental/wrappers/test_normalize_reward.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Test suite for NormalizeRewardV1.""" -import numpy as np - -from gymnasium.core import ActType -from gymnasium.experimental.wrappers import NormalizeRewardV1 -from tests.testing_env import GenericTestEnv - - -def _make_reward_env(): - """Function that returns a `GenericTestEnv` with reward=1.""" - - def step_func(self, action: ActType): - return self.observation_space.sample(), 1.0, False, False, {} - - return GenericTestEnv(step_func=step_func) - - -def test_running_mean_normalize_reward_wrapper(): - """Tests that the property `_update_running_mean` freezes/continues the running statistics updating.""" - env = _make_reward_env() - wrapped_env = NormalizeRewardV1(env) - - # Default value is True - assert wrapped_env.update_running_mean - - wrapped_env.reset() - rms_var_init = wrapped_env.rewards_running_means.var - rms_mean_init = wrapped_env.rewards_running_means.mean - - # Statistics are updated when env.step() - wrapped_env.step(None) - rms_var_updated = wrapped_env.rewards_running_means.var - rms_mean_updated = wrapped_env.rewards_running_means.mean - assert rms_var_init != rms_var_updated - assert rms_mean_init != rms_mean_updated - - # Assure property is set - wrapped_env.update_running_mean = False - assert not wrapped_env.update_running_mean - - # Statistics are frozen - wrapped_env.step(None) - assert rms_var_updated == wrapped_env.rewards_running_means.var - assert rms_mean_updated == wrapped_env.rewards_running_means.mean - - -def test_normalize_reward_wrapper(): - """Tests that the NormalizeReward does not throw an error.""" - # TODO: Functional correctness should be tested - env = _make_reward_env() - wrapped_env = NormalizeRewardV1(env) - wrapped_env.reset() - _, reward, _, _, _ = wrapped_env.step(None) - assert np.ndim(reward) == 0 - env.close() diff --git a/tests/experimental/wrappers/test_numpy_to_torch.py b/tests/experimental/wrappers/test_numpy_to_torch.py deleted file mode 100644 index 6cd9680ec..000000000 --- a/tests/experimental/wrappers/test_numpy_to_torch.py +++ /dev/null @@ -1 +0,0 @@ -"""Test suite for NumpyToTorchV0.""" diff --git a/tests/experimental/wrappers/test_order_enforcing.py b/tests/experimental/wrappers/test_order_enforcing.py deleted file mode 100644 index d513dc779..000000000 --- a/tests/experimental/wrappers/test_order_enforcing.py +++ /dev/null @@ -1 +0,0 @@ -"""Test suite for OrderEnforcingV0.""" diff --git a/tests/experimental/wrappers/test_passive_env_checker.py b/tests/experimental/wrappers/test_passive_env_checker.py deleted file mode 100644 index cd10b83c9..000000000 --- a/tests/experimental/wrappers/test_passive_env_checker.py +++ /dev/null @@ -1 +0,0 @@ -"""Test suite for PassiveEnvCheckerV0.""" diff --git a/tests/experimental/wrappers/test_pixel_observation.py b/tests/experimental/wrappers/test_pixel_observation.py deleted file mode 100644 index 4df32ed93..000000000 --- a/tests/experimental/wrappers/test_pixel_observation.py +++ /dev/null @@ -1 +0,0 @@ -"""Test suite for PixelObservationV0.""" diff --git 
a/tests/experimental/wrappers/test_record_episode_statistics.py b/tests/experimental/wrappers/test_record_episode_statistics.py deleted file mode 100644 index 1f0ede6ae..000000000 --- a/tests/experimental/wrappers/test_record_episode_statistics.py +++ /dev/null @@ -1 +0,0 @@ -"""Test suite for RecordEpisodeStatisticsV0.""" diff --git a/tests/experimental/wrappers/test_record_video.py b/tests/experimental/wrappers/test_record_video.py deleted file mode 100644 index 1dc05b11d..000000000 --- a/tests/experimental/wrappers/test_record_video.py +++ /dev/null @@ -1,168 +0,0 @@ -"""Test suite for RecordVideoV0.""" -import os -import shutil -from typing import List - -import gymnasium as gym -from gymnasium.experimental.wrappers import RecordVideoV0 - - -def test_record_video_using_default_trigger(): - """Test RecordVideo using the default episode trigger.""" - env = gym.make("CartPole-v1", render_mode="rgb_array_list") - env = RecordVideoV0(env, "videos") - env.reset() - episode_count = 0 - for _ in range(199): - action = env.action_space.sample() - _, _, terminated, truncated, _ = env.step(action) - if terminated or truncated: - env.reset() - episode_count += 1 - - env.close() - assert os.path.isdir("videos") - mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] - assert env.episode_trigger is not None - assert len(mp4_files) == sum( - env.episode_trigger(i) for i in range(episode_count + 1) - ) - shutil.rmtree("videos") - - -def test_record_video_while_rendering(): - """Test RecordVideo while calling render and using a _list render mode.""" - env = gym.make("FrozenLake-v1", render_mode="rgb_array_list") - env = RecordVideoV0(env, "videos") - env.reset() - episode_count = 0 - for _ in range(199): - action = env.action_space.sample() - _, _, terminated, truncated, _ = env.step(action) - env.render() - if terminated or truncated: - env.reset() - episode_count += 1 - - env.close() - assert os.path.isdir("videos") - mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] - assert env.episode_trigger is not None - assert len(mp4_files) == sum( - env.episode_trigger(i) for i in range(episode_count + 1) - ) - shutil.rmtree("videos") - - -def test_record_video_step_trigger(): - """Test RecordVideo defining step trigger function.""" - env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True) - env._max_episode_steps = 20 - env = RecordVideoV0(env, "videos", step_trigger=lambda x: x % 100 == 0) - env.reset() - for _ in range(199): - action = env.action_space.sample() - _, _, terminated, truncated, _ = env.step(action) - if terminated or truncated: - env.reset() - env.close() - assert os.path.isdir("videos") - mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] - shutil.rmtree("videos") - assert len(mp4_files) == 2 - - -def test_record_video_both_trigger(): - """Test RecordVideo defining both step and episode trigger functions.""" - env = gym.make( - "CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True - ) - env._max_episode_steps = 20 - env = RecordVideoV0( - env, - "videos", - step_trigger=lambda x: x == 100, - episode_trigger=lambda x: x == 0 or x == 3, - ) - env.reset() - for _ in range(199): - action = env.action_space.sample() - _, _, terminated, truncated, _ = env.step(action) - if terminated or truncated: - env.reset() - env.close() - assert os.path.isdir("videos") - mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] - shutil.rmtree("videos") - assert 
len(mp4_files) == 3 - - -def test_record_video_length(): - """Test if argument video_length of RecordVideo works properly.""" - env = gym.make("CartPole-v1", render_mode="rgb_array_list") - env._max_episode_steps = 20 - env = RecordVideoV0(env, "videos", step_trigger=lambda x: x == 0, video_length=10) - env.reset() - for _ in range(10): - action = env.action_space.sample() - env.step(action) - - assert env.recording - action = env.action_space.sample() - env.step(action) - assert not env.recording - env.close() - assert os.path.isdir("videos") - mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] - assert len(mp4_files) == 1 - shutil.rmtree("videos") - - -def test_rendering_works(): - """Test if render output is as expected when the env is wrapped with RecordVideo.""" - env = gym.make("CartPole-v1", render_mode="rgb_array_list") - env._max_episode_steps = 20 - env = RecordVideoV0(env, "videos") - env.reset() - n_steps = 10 - for _ in range(n_steps): - action = env.action_space.sample() - env.step(action) - - render_out = env.render() - assert isinstance(render_out, List) - assert len(render_out) == n_steps + 1 - render_out = env.render() - assert isinstance(render_out, List) - assert len(render_out) == 0 - env.close() - shutil.rmtree("videos") - - -def make_env(gym_id, idx, **kwargs): - """Utility function to make an env and wrap it with RecordVideo only the first time.""" - - def thunk(): - env = gym.make(gym_id, disable_env_checker=True, **kwargs) - env._max_episode_steps = 20 - if idx == 0: - env = RecordVideoV0(env, "videos", step_trigger=lambda x: x % 100 == 0) - return env - - return thunk - - -def test_record_video_within_vector(): - """Test RecordVideo used as env of SyncVectorEnv.""" - envs = gym.vector.SyncVectorEnv( - [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(2)] - ) - envs = gym.wrappers.RecordEpisodeStatistics(envs) - envs.reset() - for i in range(199): - _, _, _, _, infos = envs.step(envs.action_space.sample()) - - assert os.path.isdir("videos") - mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] - assert len(mp4_files) == 2 - shutil.rmtree("videos") diff --git a/tests/experimental/wrappers/test_render_collection.py b/tests/experimental/wrappers/test_render_collection.py deleted file mode 100644 index 9a2687519..000000000 --- a/tests/experimental/wrappers/test_render_collection.py +++ /dev/null @@ -1 +0,0 @@ -"""Test suite for RenderCollectionV0.""" diff --git a/tests/experimental/wrappers/test_rescale_action.py b/tests/experimental/wrappers/test_rescale_action.py deleted file mode 100644 index efd40559d..000000000 --- a/tests/experimental/wrappers/test_rescale_action.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Test suite for RescaleActionV0.""" -import numpy as np - -from gymnasium.experimental.wrappers import RescaleActionV0 -from gymnasium.spaces import Box -from tests.experimental.wrappers.utils import record_action_step -from tests.testing_env import GenericTestEnv - - -def test_rescale_action_wrapper(): - """Test that the action is rescale within a min / max bound.""" - env = GenericTestEnv( - step_func=record_action_step, - action_space=Box(np.array([0, 1]), np.array([1, 3])), - ) - wrapped_env = RescaleActionV0( - env, min_action=np.array([-5, 0]), max_action=np.array([5, 1]) - ) - assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1])) - - for sample_action, expected_action in ( - ( - np.array([0.0, 0.5], dtype=np.float32), - np.array([0.5, 2.0], dtype=np.float32), - ), - ( - 
np.array([-5.0, 0.0], dtype=np.float32), - np.array([0.0, 1.0], dtype=np.float32), - ), - ( - np.array([5.0, 1.0], dtype=np.float32), - np.array([1.0, 3.0], dtype=np.float32), - ), - ): - assert sample_action in wrapped_env.action_space - - _, _, _, _, info = wrapped_env.step(sample_action) - assert np.all(info["action"] == expected_action) diff --git a/tests/experimental/wrappers/test_resize_observation.py b/tests/experimental/wrappers/test_resize_observation.py deleted file mode 100644 index d663efefa..000000000 --- a/tests/experimental/wrappers/test_resize_observation.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Test suite for ResizeObservationV0.""" -from __future__ import annotations - -import numpy as np -import pytest - -import gymnasium as gym -from gymnasium.experimental.wrappers import ResizeObservationV0 -from gymnasium.spaces import Box -from tests.experimental.wrappers.utils import ( - check_obs, - record_random_obs_reset, - record_random_obs_step, -) -from tests.testing_env import GenericTestEnv - - -@pytest.mark.parametrize( - "env", - ( - GenericTestEnv( - observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8), - reset_func=record_random_obs_reset, - step_func=record_random_obs_step, - ), - GenericTestEnv( - observation_space=Box(0, 255, shape=(60, 60), dtype=np.uint8), - reset_func=record_random_obs_reset, - step_func=record_random_obs_step, - ), - ), -) -def test_resize_observation_wrapper(env): - """Test the ``ResizeObservation`` that the observation has changed size.""" - - wrapped_env = ResizeObservationV0(env, (25, 25)) - assert isinstance(wrapped_env.observation_space, Box) - assert wrapped_env.observation_space.shape[:2] == (25, 25) - - obs, info = wrapped_env.reset() - check_obs(env, wrapped_env, obs, info["obs"]) - - obs, _, _, _, info = wrapped_env.step(None) - check_obs(env, wrapped_env, obs, info["obs"]) - - -@pytest.mark.parametrize("shape", ((10, 10), (20, 20), (60, 60), (100, 100))) -def test_resize_shapes(shape: tuple[int, int]): - env = ResizeObservationV0(gym.make("CarRacing-v2"), shape) - assert env.observation_space == Box( - low=0, high=255, shape=shape + (3,), dtype=np.uint8 - ) - - obs, info = env.reset() - assert obs in env.observation_space - obs, _, _, _, _ = env.step(env.action_space.sample()) - assert obs in env.observation_space diff --git a/tests/experimental/wrappers/test_time_aware_observation.py b/tests/experimental/wrappers/test_time_aware_observation.py deleted file mode 100644 index 3fc90f769..000000000 --- a/tests/experimental/wrappers/test_time_aware_observation.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Test suite for TimeAwareObservationV0.""" - -from gymnasium.experimental.wrappers import TimeAwareObservationV0 -from gymnasium.spaces import Box, Dict, Tuple -from tests.testing_env import GenericTestEnv - - -def test_env_obs_space(): - """Test the TimeAwareObservation wrapper for three type of observation spaces.""" - env = GenericTestEnv(observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3))) - wrapped_env = TimeAwareObservationV0(env) - assert isinstance(wrapped_env.observation_space, Dict) - reset_obs, _ = wrapped_env.reset() - step_obs, _, _, _, _ = wrapped_env.step(None) - assert "time" in reset_obs and "time" in step_obs, f"{reset_obs}, {step_obs}" - - assert reset_obs in wrapped_env.observation_space - assert step_obs in wrapped_env.observation_space - - env = GenericTestEnv(observation_space=Tuple((Box(0, 1), Box(2, 3)))) - wrapped_env = TimeAwareObservationV0(env) - assert isinstance(wrapped_env.observation_space, Tuple) - 
reset_obs, _ = wrapped_env.reset() - step_obs, _, _, _, _ = wrapped_env.step(None) - assert len(reset_obs) == 3 and len(step_obs) == 3 - - assert reset_obs in wrapped_env.observation_space - assert step_obs in wrapped_env.observation_space - - env = GenericTestEnv(observation_space=Box(0, 1)) - wrapped_env = TimeAwareObservationV0(env) - assert isinstance(wrapped_env.observation_space, Dict) - reset_obs, _ = wrapped_env.reset() - step_obs, _, _, _, _ = wrapped_env.step(None) - assert isinstance(reset_obs, dict) and isinstance(step_obs, dict) - assert "obs" in reset_obs and "obs" in step_obs - assert "time" in reset_obs and "time" in step_obs - - assert reset_obs in wrapped_env.observation_space - assert step_obs in wrapped_env.observation_space - - -def test_flatten_parameter(): - """Test the flatten parameter for the TimeAwareObservation wrapper.""" - env = GenericTestEnv(observation_space=Box(0, 1)) - wrapped_env = TimeAwareObservationV0(env, flatten=True) - assert isinstance(wrapped_env.observation_space, Box) - reset_obs, _ = wrapped_env.reset() - step_obs, _, _, _, _ = wrapped_env.step(None) - assert reset_obs.shape == (2,) and step_obs.shape == (2,) - - assert reset_obs in wrapped_env.observation_space - assert step_obs in wrapped_env.observation_space - - -def test_normalize_time_parameter(): - """Test the normalize time parameter for DelayObservation wrappers.""" - # Tests the normalize_time parameter - env = GenericTestEnv(observation_space=Box(0, 1)) - wrapped_env = TimeAwareObservationV0(env, normalize_time=False) - reset_obs, _ = wrapped_env.reset() - step_obs, _, _, _, _ = wrapped_env.step(None) - assert reset_obs["time"] == 100 and step_obs["time"] == 99 - - assert reset_obs in wrapped_env.observation_space - assert step_obs in wrapped_env.observation_space - - env = GenericTestEnv(observation_space=Box(0, 1)) - wrapped_env = TimeAwareObservationV0(env, normalize_time=True) - reset_obs, _ = wrapped_env.reset() - step_obs, _, _, _, _ = wrapped_env.step(None) - assert reset_obs["time"] == 0.0 and step_obs["time"] == 0.01 - - assert reset_obs in wrapped_env.observation_space - assert step_obs in wrapped_env.observation_space diff --git a/tests/experimental/wrappers/utils.py b/tests/experimental/wrappers/utils.py deleted file mode 100644 index 58dd560c1..000000000 --- a/tests/experimental/wrappers/utils.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Utility functions for testing the experimental wrappers.""" -import gymnasium as gym -from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS -from tests.testing_env import GenericTestEnv - - -SEED = 42 -ENV_ID = "CartPole-v1" -DISCRETE_ACTION = 0 -NUM_ENVS = 3 -NUM_STEPS = 20 - - -def record_obs_reset(self: gym.Env, seed=None, options: dict = None): - """Records and uses an observation passed through options.""" - return options["obs"], {"obs": options["obs"]} - - -def record_random_obs_reset(self: gym.Env, seed=None, options=None): - """Records random observation generated by the environment.""" - obs = self.observation_space.sample() - return obs, {"obs": obs} - - -def record_action_step(self: gym.Env, action): - """Records the actions passed to the environment.""" - return 0, 0, False, False, {"action": action} - - -def record_random_obs_step(self: gym.Env, action): - """Records the observation generated by the environment.""" - obs = self.observation_space.sample() - return obs, 0, False, False, {"obs": obs} - - -def record_action_as_obs_step(self: gym.Env, action): - """Uses the action as the observation.""" - return action, 0, 
False, False, {"obs": action} - - -def check_obs( - env: gym.Env, - wrapped_env: gym.Wrapper, - transformed_obs, - original_obs, - strict: bool = True, -): - """Checks that the original and transformed observations using the environment and wrapped environment. - - Args: - env: The base environment - wrapped_env: The wrapped environment - transformed_obs: The transformed observation by the wrapped environment - original_obs: The original observation by the base environment. - strict: If to check that the observations aren't contained in the other environment. - """ - assert ( - transformed_obs in wrapped_env.observation_space - ), f"{transformed_obs}, {wrapped_env.observation_space}" - assert ( - original_obs in env.observation_space - ), f"{original_obs}, {env.observation_space}" - - if strict: - assert ( - transformed_obs not in env.observation_space - ), f"{transformed_obs}, {env.observation_space}" - assert ( - original_obs not in wrapped_env.observation_space - ), f"{original_obs}, {wrapped_env.observation_space}" - - -TESTING_OBS_ENVS = [GenericTestEnv(observation_space=space) for space in TESTING_SPACES] -TESTING_OBS_ENVS_IDS = TESTING_SPACES_IDS - -TESTING_ACTION_ENVS = [GenericTestEnv(action_space=space) for space in TESTING_SPACES] -TESTING_ACTION_ENVS_IDS = TESTING_SPACES_IDS diff --git a/tests/experimental/wrappers/vector/__init__.py b/tests/experimental/wrappers/vector/__init__.py deleted file mode 100644 index 5e4787d48..000000000 --- a/tests/experimental/wrappers/vector/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Testing suite for `gymnasium.experimental.wrappers.vector`.""" diff --git a/tests/experimental/functional/__init__.py b/tests/functional/__init__.py similarity index 100% rename from tests/experimental/functional/__init__.py rename to tests/functional/__init__.py diff --git a/tests/experimental/functional/test_func_jax_env.py b/tests/functional/test_func_jax_env.py similarity index 100% rename from tests/experimental/functional/test_func_jax_env.py rename to tests/functional/test_func_jax_env.py diff --git a/tests/experimental/functional/test_functional.py b/tests/functional/test_functional.py similarity index 97% rename from tests/experimental/functional/test_functional.py rename to tests/functional/test_functional.py index f33a07552..b47ea7a54 100644 --- a/tests/experimental/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -5,7 +5,7 @@ from typing import Any import numpy as np -from gymnasium.experimental.functional import FuncEnv +from gymnasium.functional import FuncEnv class GenericTestFuncEnv(FuncEnv): diff --git a/tests/experimental/functional/test_jax_blackjack.py b/tests/functional/test_jax_blackjack.py similarity index 100% rename from tests/experimental/functional/test_jax_blackjack.py rename to tests/functional/test_jax_blackjack.py diff --git a/tests/experimental/functional/test_jax_cliffwalking.py b/tests/functional/test_jax_cliffwalking.py similarity index 100% rename from tests/experimental/functional/test_jax_cliffwalking.py rename to tests/functional/test_jax_cliffwalking.py diff --git a/tests/test_core.py b/tests/test_core.py index 048338691..9916620c1 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,184 +1,21 @@ """Checks that the core Gymnasium API is implemented as expected.""" +from __future__ import annotations + import re -from typing import Any, Dict, Optional, SupportsFloat, Tuple +from typing import Any, SupportsFloat import numpy as np import pytest -from gymnasium import Env, ObservationWrapper, 
RewardWrapper, Wrapper, spaces -from gymnasium.core import ( - ActionWrapper, - ActType, - ObsType, - WrapperActType, - WrapperObsType, -) +import gymnasium as gym +from gymnasium import ActionWrapper, Env, ObservationWrapper, RewardWrapper, Wrapper +from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType from gymnasium.spaces import Box from gymnasium.utils import seeding -from gymnasium.wrappers import OrderEnforcing, TimeLimit +from gymnasium.wrappers import OrderEnforcing from tests.testing_env import GenericTestEnv -# ==== Old testing code - - -class ArgumentEnv(Env): - """Testing environment that records the number of times the environment is created.""" - - observation_space = spaces.Box(low=0, high=1, shape=(1,)) - action_space = spaces.Box(low=0, high=1, shape=(1,)) - calls = 0 - - def __init__(self, arg: Any): - """Constructor.""" - self.calls += 1 - self.arg = arg - - -class UnittestEnv(Env): - """Example testing environment.""" - - observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8) - action_space = spaces.Discrete(3) - - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): - """Resets the environment.""" - super().reset(seed=seed) - return self.observation_space.sample(), {"info": "dummy"} - - def step(self, action): - """Steps through the environment.""" - observation = self.observation_space.sample() # Dummy observation - return observation, 0.0, False, {} - - -class UnknownSpacesEnv(Env): - """This environment defines its observation & action spaces only after the first call to reset. - - Although this pattern is sometimes necessary when implementing a new environment (e.g. if it depends - on external resources), it is not encouraged. - """ - - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): - """Resets the environment.""" - super().reset(seed=seed) - self.observation_space = spaces.Box( - low=0, high=255, shape=(64, 64, 3), dtype=np.uint8 - ) - self.action_space = spaces.Discrete(3) - return self.observation_space.sample(), {} # Dummy observation with info - - def step(self, action): - """Steps through the environment.""" - observation = self.observation_space.sample() # Dummy observation - return observation, 0.0, False, {} - - -class OldStyleEnv(Env): - """This environment doesn't accept any arguments in reset, ideally we want to support this too (for now).""" - - def reset(self): - """Resets the environment.""" - super().reset() - return 0 - - def step(self, action): - """Steps through the environment.""" - return 0, 0, False, {} - - -class NewPropertyWrapper(Wrapper): - """Wrapper that tests setting a property.""" - - def __init__( - self, - env, - observation_space=None, - action_space=None, - reward_range=None, - metadata=None, - ): - """New property wrapper. 
- - Args: - env: The environment to wrap - observation_space: The observation space - action_space: The action space - reward_range: The reward range - metadata: The environment metadata - """ - super().__init__(env) - if observation_space is not None: - # Only set the observation space if not None to test property forwarding - self.observation_space = observation_space - if action_space is not None: - self.action_space = action_space - if reward_range is not None: - self.reward_range = reward_range - if metadata is not None: - self.metadata = metadata - - -def test_env_instantiation(): - """Tests the environment instantiation using ArgumentEnv.""" - # This looks like a pretty trivial, but given our usage of - # __new__, it's worth having. - env = ArgumentEnv("arg") - assert env.arg == "arg" - assert env.calls == 1 - - -properties = [ - { - "observation_space": spaces.Box( - low=0.0, high=1.0, shape=(64, 64, 3), dtype=np.float32 - ) - }, - {"action_space": spaces.Discrete(2)}, - {"reward_range": (-1.0, 1.0)}, - {"metadata": {"render_modes": ["human", "rgb_array_list"]}}, - { - "observation_space": spaces.Box( - low=0.0, high=1.0, shape=(64, 64, 3), dtype=np.float32 - ), - "action_space": spaces.Discrete(2), - }, -] - - -@pytest.mark.parametrize("class_", [UnittestEnv, UnknownSpacesEnv]) -@pytest.mark.parametrize("props", properties) -def test_wrapper_property_forwarding(class_, props): - """Tests wrapper property forwarding.""" - env = class_() - env = NewPropertyWrapper(env, **props) - - # If UnknownSpacesEnv, then call reset to define the spaces - if isinstance(env.unwrapped, UnknownSpacesEnv): - _ = env.reset() - - # Test the properties set by the wrapper - for key, value in props.items(): - assert getattr(env, key) == value - - # Otherwise, test if the properties are forwarded - all_properties = {"observation_space", "action_space", "reward_range", "metadata"} - for key in all_properties - props.keys(): - assert getattr(env, key) == getattr(env.unwrapped, key) - - -def test_compatibility_with_old_style_env(): - """Test compatibility with old style environment.""" - env = OldStyleEnv() - env = OrderEnforcing(env) - env = TimeLimit(env, 100) - obs = env.reset() - assert obs == 0 - - -# ==== New testing code - - class ExampleEnv(Env): """Example testing environment.""" @@ -189,27 +26,26 @@ class ExampleEnv(Env): def step( self, action: ActType - ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: + ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: """Steps through the environment.""" return 0, 0, False, False, {} def reset( self, *, - seed: Optional[int] = None, - options: Optional[dict] = None, - ) -> Tuple[ObsType, dict]: + seed: int | None = None, + options: dict | None = None, + ) -> tuple[ObsType, dict]: """Resets the environment.""" return 0, {} -def test_gymnasium_env(): +def test_example_env(): """Tests a gymnasium environment.""" env = ExampleEnv() assert env.metadata == {"render_modes": []} assert env.render_mode is None - assert env.reward_range == (-float("inf"), float("inf")) assert env.spec is None assert env._np_random is None # pyright: ignore [reportPrivateUsage] @@ -224,14 +60,14 @@ class ExampleWrapper(Wrapper): self.new_reward = 3 def reset( - self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None - ) -> Tuple[WrapperObsType, Dict[str, Any]]: + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[WrapperObsType, dict[str, Any]]: """Resets the environment .""" return super().reset(seed=seed, 
options=options) def step( self, action: WrapperActType - ) -> Tuple[WrapperObsType, float, bool, bool, Dict[str, Any]]: + ) -> tuple[WrapperObsType, float, bool, bool, dict[str, Any]]: """Steps through the environment.""" obs, reward, termination, truncation, info = self.env.step(action) return obs, self.new_reward, termination, truncation, info @@ -241,7 +77,7 @@ class ExampleWrapper(Wrapper): return self._np_random -def test_gymnasium_wrapper(): +def test_example_wrapper(): """Tests the gymnasium wrapper works as expected.""" env = ExampleEnv() wrapper_env = ExampleWrapper(env) @@ -252,10 +88,6 @@ def test_example_wrapper(): assert env.render_mode == wrapper_env.render_mode - assert env.reward_range == wrapper_env.reward_range - wrapper_env.reward_range = (-1.0, 1.0) - assert env.reward_range != wrapper_env.reward_range - assert env.spec == wrapper_env.spec env.observation_space = Box(0, 1) @@ -277,7 +109,7 @@ with pytest.raises( AttributeError, match=re.escape( - "Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`." + "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`." ), ): print(wrapper_env.access_hidden_np_random()) @@ -307,7 +139,7 @@ class ExampleActionWrapper(ActionWrapper): return np.array([1]) -def test_wrapper_types(): +def test_reward_observation_action_wrapper(): """Tests the observation, action and reward wrapper examples.""" env = GenericTestEnv() @@ -326,3 +158,48 @@ def test_wrapper_types(): action_env = ExampleActionWrapper(env) obs, _, _, _, _ = action_env.step(0) assert obs == np.array([1]) + + +def test_get_set_wrapper_attr(): + """Tests `get_wrapper_attr` and `set_wrapper_attr` through a stack of wrappers.""" + env = gym.make("CartPole-v1") + + # Test get_wrapper_attr + with pytest.raises(AttributeError): + env.gravity + assert env.unwrapped.gravity is not None + assert env.get_wrapper_attr("gravity") is not None + + with pytest.raises(AttributeError): + env.unknown_attr + with pytest.raises(AttributeError): + env.get_wrapper_attr("unknown_attr") + + # Test set_wrapper_attr + env.set_wrapper_attr("gravity", 10.0) + with pytest.raises(AttributeError): + env.gravity + assert env.unwrapped.gravity == 10.0 + assert env.get_wrapper_attr("gravity") == 10.0 + + env.gravity = 5.0 + assert env.gravity == 5.0 + assert env.get_wrapper_attr("gravity") == 5.0 + assert env.env.get_wrapper_attr("gravity") == 10.0 + + # Test with OrderEnforcing (intermediate wrapper) + assert not isinstance(env, OrderEnforcing) + + with pytest.raises(AttributeError): + env._disable_render_order_enforcing + with pytest.raises(AttributeError): + env.unwrapped._disable_render_order_enforcing + assert env.get_wrapper_attr("_disable_render_order_enforcing") is False + + env.set_wrapper_attr("_disable_render_order_enforcing", True) + + with pytest.raises(AttributeError): + env._disable_render_order_enforcing + with pytest.raises(AttributeError): + env.unwrapped._disable_render_order_enforcing + assert env.get_wrapper_attr("_disable_render_order_enforcing") is True
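The `get_wrapper_attr`/`set_wrapper_attr` API tested above is the v1.0.0 replacement for reaching through wrapper stacks via plain attribute access. A minimal sketch of the semantics the test asserts, assuming a registered `CartPole-v1` whose unwrapped environment defines `gravity` (true for the classic-control implementation); the values are illustrative:

```python
import gymnasium as gym

env = gym.make("CartPole-v1")  # base env wrapped in e.g. TimeLimit, OrderEnforcing

# Plain attribute access no longer tunnels through the wrapper stack,
# but get_wrapper_attr searches every layer down to the base environment.
assert env.get_wrapper_attr("gravity") == env.unwrapped.gravity

# set_wrapper_attr writes to the layer that owns the attribute,
# so the base environment sees the new value.
env.set_wrapper_attr("gravity", 10.0)
assert env.unwrapped.gravity == 10.0

# Assigning on the outermost wrapper instead shadows the inner value.
env.gravity = 5.0
assert env.get_wrapper_attr("gravity") == 5.0
assert env.unwrapped.gravity == 10.0
```

The exact wrapper stack returned by `gym.make` depends on the registration defaults, so treat this as a sketch of the lookup order rather than a guaranteed contract.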
diff --git a/tests/utils/test_save_video.py b/tests/utils/test_save_video.py index 50d465188..f23667880 100644 --- a/tests/utils/test_save_video.py +++ b/tests/utils/test_save_video.py @@ -81,8 +81,11 @@ def test_record_video_within_vector(): n_steps = 199 expected_video = 2 - envs = gym.vector.make( - "CartPole-v1", num_envs=2, asynchronous=True, render_mode="rgb_array_list" + envs = gym.make_vec( + "CartPole-v1", + num_envs=2, + vectorization_mode="sync", + render_mode="rgb_array_list", ) envs.reset() episode_frames = []
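For context on the hunk above: `gym.vector.make` is removed in v1.0.0 in favour of `gym.make_vec`, which selects the vectorisation strategy through `vectorization_mode` rather than an `asynchronous` flag. A hedged sketch of the new entry point (the seed and environment id are illustrative):

```python
import gymnasium as gym

# "sync" steps the sub-environments sequentially in-process;
# "async" runs them in worker processes (the old asynchronous=True).
envs = gym.make_vec("CartPole-v1", num_envs=2, vectorization_mode="sync")

obs, infos = envs.reset(seed=42)
actions = envs.action_space.sample()  # one action per sub-environment
obs, rewards, terminations, truncations, infos = envs.step(actions)

# Every return value is batched across the num_envs sub-environments.
assert obs.shape == (2,) + envs.single_observation_space.shape
envs.close()
```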
diff --git a/tests/vector/__init__.py index e69de29bb..215d8185c 100644 --- a/tests/vector/__init__.py +++ b/tests/vector/__init__.py @@ -0,0 +1 @@ +"""Testing for `gymnasium.vector`.""" diff --git a/tests/vector/test_async_vector_env.py b/tests/vector/test_async_vector_env.py index 5d654ec9e..6f2609d8d 100644 --- a/tests/vector/test_async_vector_env.py +++ b/tests/vector/test_async_vector_env.py @@ -1,3 +1,5 @@ +"""Test the `AsyncVectorEnv` implementation.""" + import re from multiprocessing import TimeoutError @@ -10,8 +12,8 @@ from gymnasium.error import ( NoAsyncCallError, ) from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple -from gymnasium.vector.async_vector_env import AsyncVectorEnv -from tests.vector.utils import ( +from gymnasium.vector import AsyncVectorEnv +from tests.vector.testing_utils import ( CustomSpace, make_custom_space_env, make_env, @@ -21,6 +23,7 @@ @pytest.mark.parametrize("shared_memory", [True, False]) def test_create_async_vector_env(shared_memory): + """Test creating an async vector environment with or without shared memory.""" env_fns = [make_env("CartPole-v1", i) for i in range(8)] env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) @@ -30,6 +33,7 @@ @pytest.mark.parametrize("shared_memory", [True, False]) def test_reset_async_vector_env(shared_memory): + """Test the reset of an async vector environment with or without shared memory.""" env_fns = [make_env("CartPole-v1", i) for i in range(8)] env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) @@ -39,7 +43,6 @@ assert isinstance(env.observation_space, Box) assert isinstance(observations, np.ndarray) - assert isinstance(infos, dict) assert observations.dtype == env.observation_space.dtype assert observations.shape == (8,) + env.single_observation_space.shape assert observations.shape == env.observation_space.shape @@ -59,13 +62,32 @@ assert all([isinstance(info, dict) for info in infos]) +def test_render_async_vector(): + envs = AsyncVectorEnv( + [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(3)] + ) + assert envs.render_mode == "rgb_array" + + envs.reset() + rendered_frames = envs.render() + assert isinstance(rendered_frames, tuple) + assert len(rendered_frames) == envs.num_envs + assert all(isinstance(frame, np.ndarray) for frame in rendered_frames) + envs.close() + + envs = AsyncVectorEnv([make_env("CartPole-v1", i) for i in range(3)]) + assert envs.render_mode is None + envs.close() + + @pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("use_single_action_space", [True, False]) def test_step_async_vector_env(shared_memory, use_single_action_space): + """Test stepping the async vector environment with and without shared memory.""" env_fns = [make_env("CartPole-v1", i) for i in range(8)] env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - observations = env.reset() + env.reset() assert isinstance(env.single_action_space, Discrete) assert isinstance(env.action_space, MultiDiscrete) @@ -74,7 +96,7 @@ actions = [env.single_action_space.sample() for _ in range(8)] else: actions = env.action_space.sample() - observations, rewards, terminateds, truncateds, _ = env.step(actions) + observations, rewards, terminations, truncations, _ = env.step(actions) env.close() @@ -89,25 +111,26 @@ def test_step_async_vector_env(shared_memory, use_single_action_space): assert rewards.ndim == 1 assert rewards.size == 8 - assert isinstance(terminateds, np.ndarray) - assert terminateds.dtype == np.bool_ - assert terminateds.ndim == 1 - assert terminateds.size == 8 + assert isinstance(terminations, np.ndarray) + assert terminations.dtype == np.bool_ + assert terminations.ndim == 1 + assert terminations.size == 8 - assert isinstance(truncateds, np.ndarray) - assert truncateds.dtype == np.bool_ - assert truncateds.ndim == 1 - assert truncateds.size == 8 + assert isinstance(truncations, np.ndarray) + assert truncations.dtype == np.bool_ + assert truncations.ndim == 1 + assert truncations.size == 8 @pytest.mark.parametrize("shared_memory", [True, False]) def test_call_async_vector_env(shared_memory): + """Test `call` with an async vector environment.""" env_fns = [ make_env("CartPole-v1", i, render_mode="rgb_array_list") for i in range(4) ] env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) - _ = env.reset() + env.reset() images = env.call("render") gravity = env.call("gravity") @@ -128,6 +151,7 @@ def test_call_async_vector_env(shared_memory): @pytest.mark.parametrize("shared_memory", [True, False]) def test_set_attr_async_vector_env(shared_memory): + """Test `set_attr` for an async vector environment with or without shared memory.""" env_fns = [make_env("CartPole-v1", i) for i in range(4)] env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) @@ -140,6 +164,7 @@ def test_set_attr_async_vector_env(shared_memory): @pytest.mark.parametrize("shared_memory", [True, False]) def test_copy_async_vector_env(shared_memory): + """Test observations are a copy of the true observation with and without shared memory.""" env_fns = [make_env("CartPole-v1", i) for i in range(8)] # TODO, these tests do nothing, understand the purpose of the tests and fix them @@ -152,6 +177,7 @@ def test_copy_async_vector_env(shared_memory): @pytest.mark.parametrize("shared_memory", [True, False]) def test_no_copy_async_vector_env(shared_memory): + """Test observations are not a copy of the true observation with and without shared memory.""" env_fns = [make_env("CartPole-v1", i) for i in range(8)] # TODO, these tests do nothing, understand the purpose of the tests and fix them @@ -164,6 +190,7 @@ def test_no_copy_async_vector_env(shared_memory): @pytest.mark.parametrize("shared_memory", [True, False]) def test_reset_timeout_async_vector_env(shared_memory): + """Test timeout error on reset with and without shared memory.""" env_fns = [make_slow_env(0.3, i) for i in range(4)] env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) @@ -176,18 +203,20 @@ def test_reset_timeout_async_vector_env(shared_memory): @pytest.mark.parametrize("shared_memory", [True, False]) def test_step_timeout_async_vector_env(shared_memory): + """Test timeout error on step with and without shared memory.""" env_fns = [make_slow_env(0.0, i) for i in range(4)] env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) with pytest.raises(TimeoutError): env.reset() env.step_async(np.array([0.1, 0.1, 0.3, 0.1])) - observations, rewards, terminateds, truncateds, _ = env.step_wait(timeout=0.1) + observations, rewards, terminations, truncations, _ = env.step_wait(timeout=0.1) env.close(terminate=True)
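The timeout test above exercises `AsyncVectorEnv`'s two-phase stepping protocol, which `env.step` wraps. A rough sketch of that protocol, assuming `CartPole-v1` is registered (`make_thunk` is a hypothetical stand-in for the `make_env`/`make_slow_env` factories used by this suite):

```python
import numpy as np
import gymnasium as gym
from gymnasium.vector import AsyncVectorEnv

def make_thunk(env_id):
    return lambda: gym.make(env_id)

envs = AsyncVectorEnv([make_thunk("CartPole-v1") for _ in range(4)])
envs.reset(seed=123)

# step_async dispatches actions to the worker processes and returns
# immediately; step_wait blocks until every worker replies, raising
# multiprocessing.TimeoutError if `timeout` seconds elapse first.
envs.step_async(np.zeros(4, dtype=np.int64))
obs, rewards, terminations, truncations, infos = envs.step_wait(timeout=10)

# envs.step(actions) is simply step_async(actions) followed by step_wait().
envs.close()
```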
@pytest.mark.parametrize("shared_memory", [True, False]) def test_reset_out_of_order_async_vector_env(shared_memory): + """Test reset being called out of order with and without shared memory.""" env_fns = [make_env("CartPole-v1", i) for i in range(4)] env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) @@ -224,6 +253,7 @@ @pytest.mark.parametrize("shared_memory", [True, False]) def test_step_out_of_order_async_vector_env(shared_memory): + """Test step out of order with and without shared memory.""" env_fns = [make_env("CartPole-v1", i) for i in range(4)] env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) @@ -259,6 +289,7 @@ @pytest.mark.parametrize("shared_memory", [True, False]) def test_already_closed_async_vector_env(shared_memory): + """Test the error raised if a function is called after the environment is already closed.""" env_fns = [make_env("CartPole-v1", i) for i in range(4)] with pytest.raises(ClosedEnvironmentError): env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) @@ -268,6 +299,7 @@ @pytest.mark.parametrize("shared_memory", [True, False]) def test_check_spaces_async_vector_env(shared_memory): + """Test check spaces for async vector environment with and without shared memory.""" # CartPole-v1 - observation_space: Box(4,), action_space: Discrete(2) env_fns = [make_env("CartPole-v1", i) for i in range(8)] # FrozenLake-v1 - Discrete(16), action_space: Discrete(4) @@ -278,6 +310,7 @@ def test_custom_space_async_vector_env(): + """Test custom spaces with async vector environment.""" env_fns = [make_custom_space_env(i) for i in range(4)] env = AsyncVectorEnv(env_fns, shared_memory=False) @@ -287,7 +320,7 @@ def test_custom_space_async_vector_env(): assert isinstance(env.action_space, Tuple) actions = ("action-2", "action-3", "action-5", "action-7") - step_observations, rewards, terminateds, truncateds, _ = env.step(actions) + step_observations, rewards, terminations, truncations, _ = env.step(actions) env.close() @@ -307,6 +340,7 @@ def test_custom_space_async_vector_env_shared_memory(): + """Test custom space with shared memory.""" env_fns = [make_custom_space_env(i) for i in range(4)] with pytest.raises(ValueError): env = AsyncVectorEnv(env_fns, shared_memory=True) diff --git a/tests/vector/test_numpy_utils.py b/tests/vector/test_numpy_utils.py deleted file mode 100644 index 26e6ad232..000000000 --- a/tests/vector/test_numpy_utils.py +++ /dev/null @@ -1,142 +0,0 @@ -from collections import OrderedDict - -import numpy as np -import pytest - -from gymnasium.spaces import Dict, Tuple -from gymnasium.vector.utils.numpy_utils import concatenate, create_empty_array -from gymnasium.vector.utils.spaces import BaseGymSpaces -from tests.vector.utils import spaces - - -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) -def test_concatenate(space): - def assert_type(lhs, rhs, n): - # Special case: if rhs is a list of scalars, lhs must be an np.ndarray - if np.isscalar(rhs[0]): - assert isinstance(lhs, np.ndarray) - assert all([np.isscalar(rhs[i]) for i in range(n)]) - else: - assert all([isinstance(rhs[i], type(lhs)) for i in range(n)]) - - def assert_nested_equal(lhs, rhs, n): - assert isinstance(rhs, list) - assert (n > 0) and (len(rhs) == n) - assert_type(lhs, rhs, n) - if isinstance(lhs, np.ndarray): - assert lhs.shape[0] == n - for i in range(n): - assert np.all(lhs[i] == rhs[i]) - - elif isinstance(lhs, tuple): - for i in range(len(lhs)): - rhs_T_i =
[rhs[j][i] for j in range(n)] - assert_nested_equal(lhs[i], rhs_T_i, n) - - elif isinstance(lhs, OrderedDict): - for key in lhs.keys(): - rhs_T_key = [rhs[j][key] for j in range(n)] - assert_nested_equal(lhs[key], rhs_T_key, n) - - else: - raise TypeError(f"Got unknown type `{type(lhs)}`.") - - samples = [space.sample() for _ in range(8)] - array = create_empty_array(space, n=8) - concatenated = concatenate(space, samples, array) - - assert np.all(concatenated == array) - assert_nested_equal(array, samples, n=8) - - -@pytest.mark.parametrize("n", [1, 8]) -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) -def test_create_empty_array(space, n): - def assert_nested_type(arr, space, n): - if isinstance(space, BaseGymSpaces): - assert isinstance(arr, np.ndarray) - assert arr.dtype == space.dtype - assert arr.shape == (n,) + space.shape - - elif isinstance(space, Tuple): - assert isinstance(arr, tuple) - assert len(arr) == len(space.spaces) - for i in range(len(arr)): - assert_nested_type(arr[i], space.spaces[i], n) - - elif isinstance(space, Dict): - assert isinstance(arr, OrderedDict) - assert set(arr.keys()) ^ set(space.spaces.keys()) == set() - for key in arr.keys(): - assert_nested_type(arr[key], space.spaces[key], n) - - else: - raise TypeError(f"Got unknown type `{type(arr)}`.") - - array = create_empty_array(space, n=n, fn=np.empty) - assert_nested_type(array, space, n=n) - - -@pytest.mark.parametrize("n", [1, 8]) -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) -def test_create_empty_array_zeros(space, n): - def assert_nested_type(arr, space, n): - if isinstance(space, BaseGymSpaces): - assert isinstance(arr, np.ndarray) - assert arr.dtype == space.dtype - assert arr.shape == (n,) + space.shape - assert np.all(arr == 0) - - elif isinstance(space, Tuple): - assert isinstance(arr, tuple) - assert len(arr) == len(space.spaces) - for i in range(len(arr)): - assert_nested_type(arr[i], space.spaces[i], n) - - elif isinstance(space, Dict): - assert isinstance(arr, OrderedDict) - assert set(arr.keys()) ^ set(space.spaces.keys()) == set() - for key in arr.keys(): - assert_nested_type(arr[key], space.spaces[key], n) - - else: - raise TypeError(f"Got unknown type `{type(arr)}`.") - - array = create_empty_array(space, n=n, fn=np.zeros) - assert_nested_type(array, space, n=n) - - -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) -def test_create_empty_array_none_shape_ones(space): - def assert_nested_type(arr, space): - if isinstance(space, BaseGymSpaces): - assert isinstance(arr, np.ndarray) - assert arr.dtype == space.dtype - assert arr.shape == space.shape - assert np.all(arr == 1) - - elif isinstance(space, Tuple): - assert isinstance(arr, tuple) - assert len(arr) == len(space.spaces) - for i in range(len(arr)): - assert_nested_type(arr[i], space.spaces[i]) - - elif isinstance(space, Dict): - assert isinstance(arr, OrderedDict) - assert set(arr.keys()) ^ set(space.spaces.keys()) == set() - for key in arr.keys(): - assert_nested_type(arr[key], space.spaces[key]) - - else: - raise TypeError(f"Got unknown type `{type(arr)}`.") - - array = create_empty_array(space, n=None, fn=np.ones) - assert_nested_type(array, space) diff --git a/tests/vector/test_shared_memory.py b/tests/vector/test_shared_memory.py deleted file mode 100644 index b6bdfef2a..000000000 --- a/tests/vector/test_shared_memory.py +++ /dev/null @@ -1,189 +0,0 @@ -import multiprocessing as 
mp -from collections import OrderedDict -from multiprocessing import Array, Process -from multiprocessing.sharedctypes import SynchronizedArray - -import numpy as np -import pytest - -from gymnasium.error import CustomSpaceError -from gymnasium.spaces import Dict, Tuple -from gymnasium.vector.utils.shared_memory import ( - create_shared_memory, - read_from_shared_memory, - write_to_shared_memory, -) -from gymnasium.vector.utils.spaces import BaseGymSpaces -from tests.vector.utils import custom_spaces, spaces - - -expected_types = [ - Array("d", 1), - Array("f", 1), - Array("f", 3), - Array("f", 4), - Array("B", 1), - Array("B", 32 * 32 * 3), - Array("i", 1), - Array("i", 1), - (Array("i", 1), Array("i", 1)), - (Array("i", 1), Array("f", 2)), - Array("B", 3), - Array("B", 3), - Array("B", 19), - OrderedDict([("position", Array("i", 1)), ("velocity", Array("f", 1))]), - OrderedDict( - [ - ("position", OrderedDict([("x", Array("i", 1)), ("y", Array("i", 1))])), - ("velocity", (Array("i", 1), Array("B", 1))), - ] - ), -] - - -@pytest.mark.parametrize("n", [1, 8]) -@pytest.mark.parametrize( - "space,expected_type", - list(zip(spaces, expected_types)), - ids=[space.__class__.__name__ for space in spaces], -) -@pytest.mark.parametrize( - "ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"] -) -def test_create_shared_memory(space, expected_type, n, ctx): - if ctx not in mp.get_all_start_methods(): - pytest.skip( - f"Multiprocessing start method {ctx} not available on this platform." - ) - - def assert_nested_type(lhs, rhs, n): - assert type(lhs) == type(rhs) - if isinstance(lhs, (list, tuple)): - assert len(lhs) == len(rhs) - for lhs_, rhs_ in zip(lhs, rhs): - assert_nested_type(lhs_, rhs_, n) - - elif isinstance(lhs, (dict, OrderedDict)): - assert set(lhs.keys()) ^ set(rhs.keys()) == set() - for key in lhs.keys(): - assert_nested_type(lhs[key], rhs[key], n) - - elif isinstance(lhs, SynchronizedArray): - # Assert the length of the array - assert len(lhs[:]) == n * len(rhs[:]) - # Assert the data type - assert isinstance(lhs[0], type(rhs[0])) - else: - raise TypeError(f"Got unknown type `{type(lhs)}`.") - - ctx = mp if (ctx is None) else mp.get_context(ctx) - shared_memory = create_shared_memory(space, n=n, ctx=ctx) - assert_nested_type(shared_memory, expected_type, n=n) - - -@pytest.mark.parametrize("n", [1, 8]) -@pytest.mark.parametrize( - "ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"] -) -@pytest.mark.parametrize("space", custom_spaces) -def test_create_shared_memory_custom_space(n, ctx, space): - if ctx not in mp.get_all_start_methods(): - pytest.skip( - f"Multiprocessing start method {ctx} not available on this platform." 
- ) - - ctx = mp if (ctx is None) else mp.get_context(ctx) - with pytest.raises(CustomSpaceError): - create_shared_memory(space, n=n, ctx=ctx) - - -def _write_shared_memory(space, i, shared_memory, sample): - write_to_shared_memory(space, i, sample, shared_memory) - - -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) -def test_write_to_shared_memory(space): - def assert_nested_equal(lhs, rhs): - assert isinstance(rhs, list) - if isinstance(lhs, (list, tuple)): - for i in range(len(lhs)): - assert_nested_equal(lhs[i], [rhs_[i] for rhs_ in rhs]) - - elif isinstance(lhs, (dict, OrderedDict)): - for key in lhs.keys(): - assert_nested_equal(lhs[key], [rhs_[key] for rhs_ in rhs]) - - elif isinstance(lhs, SynchronizedArray): - assert np.all(np.array(lhs[:]) == np.stack(rhs, axis=0).flatten()) - - else: - raise TypeError(f"Got unknown type `{type(lhs)}`.") - - shared_memory_n8 = create_shared_memory(space, n=8) - samples = [space.sample() for _ in range(8)] - - processes = [ - Process( - target=_write_shared_memory, args=(space, i, shared_memory_n8, samples[i]) - ) - for i in range(8) - ] - - for process in processes: - process.start() - for process in processes: - process.join() - - assert_nested_equal(shared_memory_n8, samples) - - -def _process_write(space, i, shared_memory, sample): - write_to_shared_memory(space, i, sample, shared_memory) - - -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) -def test_read_from_shared_memory(space): - def assert_nested_equal(lhs, rhs, space, n): - assert isinstance(rhs, list) - if isinstance(space, Tuple): - assert isinstance(lhs, tuple) - for i in range(len(lhs)): - assert_nested_equal( - lhs[i], [rhs_[i] for rhs_ in rhs], space.spaces[i], n - ) - - elif isinstance(space, Dict): - assert isinstance(lhs, OrderedDict) - for key in lhs.keys(): - assert_nested_equal( - lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n - ) - - elif isinstance(space, BaseGymSpaces): - assert isinstance(lhs, np.ndarray) - assert lhs.shape == ((n,) + space.shape) - assert lhs.dtype == space.dtype - assert np.all(lhs == np.stack(rhs, axis=0)) - - else: - raise TypeError(f"Got unknown type `{type(space)}`") - - shared_memory_n8 = create_shared_memory(space, n=8) - memory_view_n8 = read_from_shared_memory(space, shared_memory_n8, n=8) - samples = [space.sample() for _ in range(8)] - - processes = [ - Process(target=_process_write, args=(space, i, shared_memory_n8, samples[i])) - for i in range(8) - ] - - for process in processes: - process.start() - for process in processes: - process.join() - - assert_nested_equal(memory_view_n8, samples, space, n=8) diff --git a/tests/vector/test_spaces.py b/tests/vector/test_spaces.py deleted file mode 100644 index 517d9ffef..000000000 --- a/tests/vector/test_spaces.py +++ /dev/null @@ -1,205 +0,0 @@ -import copy - -import numpy as np -import pytest -from numpy.testing import assert_array_equal - -from gymnasium.spaces import Box, Dict, MultiDiscrete, Space, Tuple -from gymnasium.vector.utils.spaces import batch_space, iterate -from tests.vector.utils import CustomSpace, assert_rng_equal, custom_spaces, spaces - - -expected_batch_spaces_4 = [ - Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64), - Box(low=0.0, high=10.0, shape=(4, 1), dtype=np.float64), - Box( - low=np.array( - [[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]] - ), - high=np.array( - [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 
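The write/read pair above round-trips through that buffer: `write_to_shared_memory` stores one sample at a given index, and `read_from_shared_memory` exposes the whole buffer as a batched NumPy view. A single-process sketch (the deleted tests perform the writes from worker processes instead):

```python
import multiprocessing as mp

import numpy as np

from gymnasium.spaces import Box
from gymnasium.vector.utils import (
    create_shared_memory,
    read_from_shared_memory,
    write_to_shared_memory,
)

space = Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64)
shared_memory = create_shared_memory(space, n=3, ctx=mp)
# A (3, 4) NumPy view over the shared buffer; the writes below show up in it.
view = read_from_shared_memory(space, shared_memory, n=3)

samples = [space.sample() for _ in range(3)]
for index, sample in enumerate(samples):
    write_to_shared_memory(space, index, sample, shared_memory)

assert np.array_equal(view, np.stack(samples, axis=0))
```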
1.0]] - ), - dtype=np.float64, - ), - Box( - low=np.array( - [ - [[-1.0, 0.0], [0.0, -1.0]], - [[-1.0, 0.0], [0.0, -1.0]], - [[-1.0, 0.0], [0.0, -1]], - [[-1.0, 0.0], [0.0, -1.0]], - ] - ), - high=np.ones((4, 2, 2)), - dtype=np.float64, - ), - Box(low=0, high=255, shape=(4,), dtype=np.uint8), - Box(low=0, high=255, shape=(4, 32, 32, 3), dtype=np.uint8), - MultiDiscrete([2, 2, 2, 2]), - MultiDiscrete([5, 5, 5, 5], start=[-2, -2, -2, -2]), - Tuple((MultiDiscrete([3, 3, 3, 3]), MultiDiscrete([5, 5, 5, 5]))), - Tuple( - ( - MultiDiscrete([7, 7, 7, 7]), - Box( - low=np.array([[0.0, -1.0], [0.0, -1.0], [0.0, -1.0], [0.0, -1]]), - high=np.array([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]), - dtype=np.float64, - ), - ) - ), - Box( - low=np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]), - high=np.array([[10, 12, 16], [10, 12, 16], [10, 12, 16], [10, 12, 16]]), - dtype=np.int64, - ), - Box( - low=np.array([[-5, -7, -9], [-5, -7, -9], [-5, -7, -9], [-5, -7, -9]]), - high=np.array([[4, 6, 8], [4, 6, 8], [4, 6, 8], [4, 6, 8]]), - dtype=np.int64, - ), - Box(low=0, high=1, shape=(4, 19), dtype=np.int8), - Dict( - { - "position": MultiDiscrete([23, 23, 23, 23]), - "velocity": Box(low=0.0, high=1.0, shape=(4, 1), dtype=np.float64), - } - ), - Dict( - { - "position": Dict( - { - "x": MultiDiscrete([29, 29, 29, 29]), - "y": MultiDiscrete([31, 31, 31, 31]), - } - ), - "velocity": Tuple( - ( - MultiDiscrete([37, 37, 37, 37]), - Box(low=0, high=255, shape=(4,), dtype=np.uint8), - ) - ), - } - ), -] - -expected_custom_batch_spaces_4 = [ - Tuple((CustomSpace(), CustomSpace(), CustomSpace(), CustomSpace())), - Tuple( - ( - Tuple((CustomSpace(), CustomSpace(), CustomSpace(), CustomSpace())), - Box(low=0, high=255, shape=(4,), dtype=np.uint8), - ) - ), -] - - -@pytest.mark.parametrize( - "space,expected_batch_space_4", - list(zip(spaces, expected_batch_spaces_4)), - ids=[space.__class__.__name__ for space in spaces], -) -def test_batch_space(space, expected_batch_space_4): - batch_space_4 = batch_space(space, n=4) - assert batch_space_4 == expected_batch_space_4 - - -@pytest.mark.parametrize( - "space,expected_batch_space_4", - list(zip(custom_spaces, expected_custom_batch_spaces_4)), - ids=[space.__class__.__name__ for space in custom_spaces], -) -def test_batch_space_custom_space(space, expected_batch_space_4): - batch_space_4 = batch_space(space, n=4) - assert batch_space_4 == expected_batch_space_4 - - -@pytest.mark.parametrize( - "space,batch_space", - list(zip(spaces, expected_batch_spaces_4)), - ids=[space.__class__.__name__ for space in spaces], -) -def test_iterate(space, batch_space): - items = batch_space.sample() - iterator = iterate(batch_space, items) - i = 0 - for i, item in enumerate(iterator): - assert item in space - assert i == 3 - - -@pytest.mark.parametrize( - "space,batch_space", - list(zip(custom_spaces, expected_custom_batch_spaces_4)), - ids=[space.__class__.__name__ for space in custom_spaces], -) -def test_iterate_custom_space(space, batch_space): - items = batch_space.sample() - iterator = iterate(batch_space, items) - i = 0 - for i, item in enumerate(iterator): - assert item in space - assert i == 3 - - -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) -@pytest.mark.parametrize("n", [4, 5], ids=[f"n={n}" for n in [4, 5]]) -@pytest.mark.parametrize( - "base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]] -) -def test_rng_different_at_each_index(space: Space, n: int, base_seed: int): - """ - Tests that 
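The `expected_batch_spaces_4` table above pins down the `batch_space` mapping: `Discrete` batches to `MultiDiscrete`, `Box` gains a leading batch axis, and composite spaces are batched per leaf. Two representative cases as a sketch:

```python
import numpy as np

from gymnasium.spaces import Box, Discrete, MultiDiscrete
from gymnasium.vector.utils import batch_space

assert batch_space(Discrete(2), n=4) == MultiDiscrete([2, 2, 2, 2])
assert batch_space(
    Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float64), n=4
) == Box(low=-1.0, high=1.0, shape=(4, 3), dtype=np.float64)
```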
the rng values produced at each index are different - to prevent if the rng is copied for each subspace - """ - space.seed(base_seed) - - batched_space = batch_space(space, n) - assert space.np_random is not batched_space.np_random - assert_rng_equal(space.np_random, batched_space.np_random) - - batched_sample = batched_space.sample() - sample = list(iterate(batched_space, batched_sample)) - assert not all(np.all(element == sample[0]) for element in sample), sample - - -@pytest.mark.parametrize( - "space", spaces, ids=[space.__class__.__name__ for space in spaces] -) -@pytest.mark.parametrize("n", [1, 2, 5], ids=[f"n={n}" for n in [1, 2, 5]]) -@pytest.mark.parametrize( - "base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]] -) -def test_deterministic(space: Space, n: int, base_seed: int): - """Tests the batched spaces are deterministic by using a copied version""" - # Copy the spaces and check that the np_random are not reference equal - space_a = space - space_a.seed(base_seed) - space_b = copy.deepcopy(space_a) - assert_rng_equal(space_a.np_random, space_b.np_random) - assert space_a.np_random is not space_b.np_random - - # Batch the spaces and check that the np_random are not reference equal - space_a_batched = batch_space(space_a, n) - space_b_batched = batch_space(space_b, n) - assert_rng_equal(space_a_batched.np_random, space_b_batched.np_random) - assert space_a_batched.np_random is not space_b_batched.np_random - # Create that the batched space is not reference equal to the origin spaces - assert space_a.np_random is not space_a_batched.np_random - - # Check that batched space a and b random number generator are not effected by the original space - space_a.sample() - space_a_batched_sample = space_a_batched.sample() - space_b_batched_sample = space_b_batched.sample() - for a_sample, b_sample in zip( - iterate(space_a_batched, space_a_batched_sample), - iterate(space_b_batched, space_b_batched_sample), - ): - if isinstance(a_sample, tuple): - assert len(a_sample) == len(b_sample) - for a_subsample, b_subsample in zip(a_sample, b_sample): - assert_array_equal(a_subsample, b_subsample) - else: - assert_array_equal(a_sample, b_sample) diff --git a/tests/vector/test_sync_vector_env.py b/tests/vector/test_sync_vector_env.py index 2ccdaab6b..8ec187b98 100644 --- a/tests/vector/test_sync_vector_env.py +++ b/tests/vector/test_sync_vector_env.py @@ -1,11 +1,13 @@ +"""Test the `SyncVectorEnv` implementation.""" + import numpy as np import pytest from gymnasium.envs.registration import EnvSpec from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple -from gymnasium.vector.sync_vector_env import SyncVectorEnv +from gymnasium.vector import SyncVectorEnv from tests.envs.utils import all_testing_env_specs -from tests.vector.utils import ( +from tests.vector.testing_utils import ( CustomSpace, assert_rng_equal, make_custom_space_env, @@ -14,6 +16,7 @@ from tests.vector.utils import ( def test_create_sync_vector_env(): + """Tests creating the sync vector environment.""" env_fns = [make_env("FrozenLake-v1", i) for i in range(8)] env = SyncVectorEnv(env_fns) env.close() @@ -22,6 +25,7 @@ def test_create_sync_vector_env(): def test_reset_sync_vector_env(): + """Tests sync vector `reset` function.""" env_fns = [make_env("CartPole-v1", i) for i in range(8)] env = SyncVectorEnv(env_fns) observations, infos = env.reset() @@ -29,7 +33,6 @@ def test_reset_sync_vector_env(): assert isinstance(env.observation_space, Box) assert isinstance(observations, np.ndarray) - assert 
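The two seeding tests above rely on `batch_space` copying the source space's RNG state into the batched space: the generators are distinct objects with equal state, so batched sampling stays reproducible without aliasing the original generator. A sketch of both properties:

```python
from gymnasium.spaces import Discrete
from gymnasium.vector.utils import batch_space, iterate

space = Discrete(10)
space.seed(123)
batched = batch_space(space, n=4)

# Distinct Generator objects, identical internal state.
assert batched.np_random is not space.np_random
assert batched.np_random.bit_generator.state == space.np_random.bit_generator.state

# iterate() splits a batched sample back into n per-env items.
items = list(iterate(batched, batched.sample()))
assert len(items) == 4
```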
isinstance(infos, dict) assert observations.dtype == env.observation_space.dtype assert observations.shape == (8,) + env.single_observation_space.shape assert observations.shape == env.observation_space.shape @@ -39,10 +42,9 @@ @pytest.mark.parametrize("use_single_action_space", [True, False]) def test_step_sync_vector_env(use_single_action_space): - env_fns = [make_env("FrozenLake-v1", i) for i in range(8)] - - env = SyncVectorEnv(env_fns) - observations = env.reset() + """Test sync vector `step` function.""" + env = SyncVectorEnv([make_env("FrozenLake-v1", i) for i in range(8)]) + env.reset() assert isinstance(env.single_action_space, Discrete) assert isinstance(env.action_space, MultiDiscrete) @@ -51,7 +53,7 @@ def test_step_sync_vector_env(use_single_action_space): actions = [env.single_action_space.sample() for _ in range(8)] else: actions = env.action_space.sample() - observations, rewards, terminateds, truncateds, _ = env.step(actions) + observations, rewards, terminations, truncations, _ = env.step(actions) env.close() @@ -66,18 +68,35 @@ def test_step_sync_vector_env(use_single_action_space): assert rewards.ndim == 1 assert rewards.size == 8 - assert isinstance(terminateds, np.ndarray) - assert terminateds.dtype == np.bool_ - assert terminateds.ndim == 1 - assert terminateds.size == 8 + assert isinstance(terminations, np.ndarray) + assert terminations.dtype == np.bool_ + assert terminations.ndim == 1 + assert terminations.size == 8 - assert isinstance(truncateds, np.ndarray) - assert truncateds.dtype == np.bool_ - assert truncateds.ndim == 1 - assert truncateds.size == 8 + assert isinstance(truncations, np.ndarray) + assert truncations.dtype == np.bool_ + assert truncations.ndim == 1 + assert truncations.size == 8 + + +def test_render_sync_vector(): + envs = SyncVectorEnv( + [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(3)] + ) + assert envs.render_mode == "rgb_array" + + envs.reset() + rendered_frames = envs.render() + assert isinstance(rendered_frames, tuple) + assert len(rendered_frames) == envs.num_envs + assert all(isinstance(frame, np.ndarray) for frame in rendered_frames) + + envs = SyncVectorEnv([make_env("CartPole-v1", i) for i in range(3)]) + assert envs.render_mode is None def test_call_sync_vector_env(): + """Test sync vector `call` on sub-environments.""" env_fns = [ make_env("CartPole-v1", i, render_mode="rgb_array_list") for i in range(4) ] @@ -103,6 +122,7 @@ def test_set_attr_sync_vector_env(): + """Test sync vector `set_attr` function.""" env_fns = [make_env("CartPole-v1", i) for i in range(4)] env = SyncVectorEnv(env_fns) @@ -114,6 +134,7 @@ def test_check_spaces_sync_vector_env(): + """Tests the sync vector `check_spaces` function.""" # CartPole-v1 - observation_space: Box(4,), action_space: Discrete(2) env_fns = [make_env("CartPole-v1", i) for i in range(8)] # FrozenLake-v1 - Discrete(16), action_space: Discrete(4) @@ -124,6 +145,7 @@ def test_custom_space_sync_vector_env(): + """Test the use of custom spaces with sync vector environment.""" env_fns = [make_custom_space_env(i) for i in range(4)] env = SyncVectorEnv(env_fns) @@ -131,18 +153,15 @@ assert isinstance(env.single_action_space, CustomSpace) assert isinstance(env.action_space, Tuple) - assert isinstance(infos, dict) actions = ("action-2", "action-3", "action-5", "action-7") - step_observations,
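The updated `SyncVectorEnv` tests above pin down the batched step signature and the new tuple-of-frames `render`. A compact usage sketch:

```python
import gymnasium as gym
from gymnasium.vector import SyncVectorEnv

envs = SyncVectorEnv(
    [lambda: gym.make("CartPole-v1", render_mode="rgb_array") for _ in range(3)]
)
obs, infos = envs.reset(seed=42)
obs, rewards, terminations, truncations, infos = envs.step(envs.action_space.sample())

# Observations are batched; rewards/terminations/truncations are length-3 arrays.
assert obs.shape == (3,) + envs.single_observation_space.shape
assert rewards.shape == terminations.shape == truncations.shape == (3,)

frames = envs.render()  # one rgb_array per sub-environment
assert isinstance(frames, tuple) and len(frames) == 3
envs.close()
```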
rewards, terminateds, truncateds, infos = env.step(actions) + step_observations, _, _, _, _ = env.step(actions) env.close() assert isinstance(env.single_observation_space, CustomSpace) assert isinstance(env.observation_space, Tuple) - assert isinstance(infos, dict) - assert isinstance(reset_observations, tuple) assert reset_observations == ("reset", "reset", "reset", "reset") @@ -156,6 +175,7 @@ def test_sync_vector_env_seed(): + """Test seeding for sync vector environments.""" env = make_env("BipedalWalker-v3", seed=123)() sync_vector_env = SyncVectorEnv([make_env("BipedalWalker-v3", seed=123)]) @@ -165,12 +185,14 @@ vector_action = sync_vector_env.action_space.sample() assert np.all(env_action == vector_action) + env.close() + @pytest.mark.parametrize( "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs] ) def test_sync_vector_determinism(spec: EnvSpec, seed: int = 123, n: int = 3): - """Check that for all environments, the sync vector envs produce the same action samples using the same seeds""" + """Check that for all environments, the sync vector envs produce the same action samples using the same seeds.""" env_1 = SyncVectorEnv([make_env(spec.id, seed=seed) for _ in range(n)]) env_2 = SyncVectorEnv([make_env(spec.id, seed=seed) for _ in range(n)]) assert_rng_equal(env_1.action_space.np_random, env_2.action_space.np_random) @@ -179,3 +201,6 @@ env_1_samples = env_1.action_space.sample() env_2_samples = env_2.action_space.sample() assert np.all(env_1_samples == env_2_samples) + + env_1.close() + env_2.close() diff --git a/tests/vector/test_vector_env.py b/tests/vector/test_vector_env.py index 16ad5eba6..211c36882 100644 --- a/tests/vector/test_vector_env.py +++ b/tests/vector/test_vector_env.py @@ -1,18 +1,19 @@ +"""Test vector environment implementations.""" + from functools import partial import numpy as np import pytest -from gymnasium.spaces import Discrete, Tuple -from gymnasium.vector.async_vector_env import AsyncVectorEnv -from gymnasium.vector.sync_vector_env import SyncVectorEnv -from gymnasium.vector.vector_env import VectorEnv +from gymnasium.spaces import Discrete +from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv from tests.testing_env import GenericTestEnv -from tests.vector.utils import CustomSpace, make_env +from tests.vector.testing_utils import make_env @pytest.mark.parametrize("shared_memory", [True, False]) def test_vector_env_equal(shared_memory): + """Test that vector environments are equal for both async and sync variants.""" env_fns = [make_env("CartPole-v1", i) for i in range(4)] num_steps = 100 @@ -33,12 +34,22 @@ actions = async_env.action_space.sample() assert actions in sync_env.action_space - # fmt: off - async_observations, async_rewards, async_terminateds, async_truncateds, async_infos = async_env.step(actions) - sync_observations, sync_rewards, sync_terminateds, sync_truncateds, sync_infos = sync_env.step(actions) - # fmt: on + ( + async_observations, + async_rewards, + async_terminations, + async_truncations, + async_infos, + ) = async_env.step(actions) + ( + sync_observations, + sync_rewards, + sync_terminations, + sync_truncations, + sync_infos, + ) = sync_env.step(actions) - if any(sync_terminateds) or any(sync_truncateds): + if any(sync_terminations) or any(sync_truncations): assert "final_observation" in async_infos assert
"_final_observation" in async_infos assert "final_observation" in sync_infos @@ -46,23 +57,13 @@ def test_vector_env_equal(shared_memory): assert np.all(async_observations == sync_observations) assert np.all(async_rewards == sync_rewards) - assert np.all(async_terminateds == sync_terminateds) - assert np.all(async_truncateds == sync_truncateds) + assert np.all(async_terminations == sync_terminations) + assert np.all(async_truncations == sync_truncations) async_env.close() sync_env.close() -def test_custom_space_vector_env(): - env = VectorEnv(4, CustomSpace(), CustomSpace()) - - assert isinstance(env.single_observation_space, CustomSpace) - assert isinstance(env.observation_space, Tuple) - - assert isinstance(env.single_action_space, CustomSpace) - assert isinstance(env.action_space, Tuple) - - @pytest.mark.parametrize( "vectoriser", ( @@ -123,3 +124,5 @@ def test_final_obs_info(vectoriser): assert info["final_observation"] == np.array([0]) and info["final_info"] == { "action": 3 } + + env.close() diff --git a/tests/vector/test_vector_env_info.py b/tests/vector/test_vector_env_info.py index 1cf333e1d..6711aa12d 100644 --- a/tests/vector/test_vector_env_info.py +++ b/tests/vector/test_vector_env_info.py @@ -1,9 +1,10 @@ +"""Test the vector environment information.""" import numpy as np import pytest import gymnasium as gym from gymnasium.vector.sync_vector_env import SyncVectorEnv -from tests.vector.utils import make_env +from tests.vector.testing_utils import make_env ENV_ID = "CartPole-v1" @@ -12,10 +13,13 @@ ENV_STEPS = 50 SEED = 42 -@pytest.mark.parametrize("asynchronous", [True, False]) -def test_vector_env_info(asynchronous): - env = gym.vector.make( - ENV_ID, num_envs=NUM_ENVS, asynchronous=asynchronous, disable_env_checker=True +@pytest.mark.parametrize("vectorization_mode", ["async", "sync"]) +def test_vector_env_info(vectorization_mode: str): + """Test vector environment info for different vectorization modes.""" + env = gym.make_vec( + ENV_ID, + num_envs=NUM_ENVS, + vectorization_mode=vectorization_mode, ) env.reset(seed=SEED) for _ in range(ENV_STEPS): @@ -36,9 +40,12 @@ def test_vector_env_info(asynchronous): assert not infos["_final_observation"][i] assert infos["final_observation"][i] is None + env.close() + @pytest.mark.parametrize("concurrent_ends", [1, 2, 3]) def test_vector_env_info_concurrent_termination(concurrent_ends): + """Test the vector environment information works with concurrent termination.""" # envs that need to terminate together will have the same action actions = [0] * concurrent_ends + [1] * (NUM_ENVS - concurrent_ends) envs = [make_env(ENV_ID, SEED) for _ in range(NUM_ENVS)] @@ -55,3 +62,5 @@ def test_vector_env_info_concurrent_termination(concurrent_ends): assert not infos["_final_observation"][i] assert infos["final_observation"][i] is None return + + envs.close() diff --git a/tests/vector/test_vector_env_wrapper.py b/tests/vector/test_vector_env_wrapper.py deleted file mode 100644 index 33c240e8c..000000000 --- a/tests/vector/test_vector_env_wrapper.py +++ /dev/null @@ -1,31 +0,0 @@ -import numpy as np - -from gymnasium.vector import VectorEnvWrapper, make - - -class DummyWrapper(VectorEnvWrapper): - def __init__(self, env): - self.env = env - self.counter = 0 - - def reset_async(self, **kwargs): - super().reset_async() - self.counter += 1 - - -def test_vector_env_wrapper_inheritance(): - env = make("FrozenLake-v1", asynchronous=False) - wrapped = DummyWrapper(env) - wrapped.reset() - assert wrapped.counter == 1 - - -def 
test_vector_env_wrapper_attributes(): - """Test if `set_attr`, `call` methods for VecEnvWrapper get correctly forwarded to the vector env it is wrapping.""" - env = make("CartPole-v1", num_envs=3) - wrapped = DummyWrapper(make("CartPole-v1", num_envs=3)) - - assert np.allclose(wrapped.call("gravity"), env.call("gravity")) - env.set_attr("gravity", [20.0, 20.0, 20.0]) - wrapped.set_attr("gravity", [20.0, 20.0, 20.0]) - assert np.allclose(wrapped.get_attr("gravity"), env.get_attr("gravity")) diff --git a/tests/vector/test_vector_make.py b/tests/vector/test_vector_make.py deleted file mode 100644 index 7150009aa..000000000 --- a/tests/vector/test_vector_make.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest - -import gymnasium as gym -from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv -from gymnasium.wrappers import TimeLimit, TransformObservation -from gymnasium.wrappers.env_checker import PassiveEnvChecker -from tests.wrappers.utils import has_wrapper - - -def test_vector_make_id(): - env = gym.vector.make("CartPole-v1") - assert isinstance(env, AsyncVectorEnv) - assert env.num_envs == 1 - env.close() - - -@pytest.mark.parametrize("num_envs", [1, 3, 10]) -def test_vector_make_num_envs(num_envs): - env = gym.vector.make("CartPole-v1", num_envs=num_envs) - assert env.num_envs == num_envs - env.close() - - -def test_vector_make_asynchronous(): - env = gym.vector.make("CartPole-v1", asynchronous=True) - assert isinstance(env, AsyncVectorEnv) - env.close() - - env = gym.vector.make("CartPole-v1", asynchronous=False) - assert isinstance(env, SyncVectorEnv) - env.close() - - -def test_vector_make_wrappers(): - env = gym.vector.make("CartPole-v1", num_envs=2, asynchronous=False) - assert isinstance(env, SyncVectorEnv) - assert len(env.envs) == 2 - - sub_env = env.envs[0] - assert isinstance(sub_env, gym.Env) - assert sub_env.spec is not None - if sub_env.spec.max_episode_steps is not None: - assert has_wrapper(sub_env, TimeLimit) - - assert all( - has_wrapper(sub_env, TransformObservation) is False for sub_env in env.envs - ) - env.close() - - env = gym.vector.make( - "CartPole-v1", - num_envs=2, - asynchronous=False, - wrappers=lambda _env: TransformObservation(_env, lambda obs: obs * 2), - ) - # As asynchronous environment are inaccessible, synchronous vector must be used - assert isinstance(env, SyncVectorEnv) - assert all(has_wrapper(sub_env, TransformObservation) for sub_env in env.envs) - - env.close() - - -def test_vector_make_disable_env_checker(): - # As asynchronous environment are inaccessible, synchronous vector must be used - env = gym.vector.make("CartPole-v1", num_envs=1, asynchronous=False) - assert isinstance(env, SyncVectorEnv) - assert has_wrapper(env.envs[0], PassiveEnvChecker) - env.close() - - env = gym.vector.make("CartPole-v1", num_envs=5, asynchronous=False) - assert isinstance(env, SyncVectorEnv) - assert has_wrapper(env.envs[0], PassiveEnvChecker) - assert all( - has_wrapper(env.envs[i], PassiveEnvChecker) is False for i in [1, 2, 3, 4] - ) - env.close() - - env = gym.vector.make( - "CartPole-v1", num_envs=3, asynchronous=False, disable_env_checker=True - ) - assert isinstance(env, SyncVectorEnv) - assert all(has_wrapper(sub_env, PassiveEnvChecker) is False for sub_env in env.envs) - env.close() diff --git a/tests/experimental/vector/test_vector_wrapper.py b/tests/vector/test_vector_wrapper.py similarity index 51% rename from tests/experimental/vector/test_vector_wrapper.py rename to tests/vector/test_vector_wrapper.py index 1a4ce81db..2701ecbb7 100644 --- 
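The two deleted modules above covered `gym.vector.make`, which this patch removes in favour of `gym.make_vec`; roughly, call sites migrate as in this sketch:

```python
import gymnasium as gym

# Pre-v1.0:
#   envs = gym.vector.make("CartPole-v1", num_envs=3, asynchronous=False)
# v1.0:
envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
assert envs.num_envs == 3
envs.close()
```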
a/tests/experimental/vector/test_vector_wrapper.py +++ b/tests/vector/test_vector_wrapper.py @@ -1,8 +1,13 @@ """Tests the vector wrappers work as expected.""" +from __future__ import annotations + +from typing import Any + import numpy as np import gymnasium as gym -from gymnasium.experimental.vector import VectorWrapper +from gymnasium.core import ObsType +from gymnasium.vector import VectorWrapper class DummyVectorWrapper(VectorWrapper): @@ -11,29 +16,41 @@ class DummyVectorWrapper(VectorWrapper): def __init__(self, env): """Initialises the wrapper with the environment creating a counter variable.""" super().__init__(env) - self.env = env + self.counter = 0 - def reset(self, **kwargs): + def reset( + self, + *, + seed: int | list[int] | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[ObsType, dict[str, Any]]: """Updates the ``counter`` each time at ``reset`` is called.""" - super().reset() self.counter += 1 + return super().reset(seed=seed, options=options) + def test_vector_env_wrapper_inheritance(): """Test vector environment wrapper inheritance.""" - env = gym.make_vec("FrozenLake-v1", vectorization_mode="async") + env = gym.make_vec("FrozenLake-v1", vectorization_mode="sync") wrapped = DummyVectorWrapper(env) wrapped.reset() assert wrapped.counter == 1 + env.close() + def test_vector_env_wrapper_attributes(): """Test if `set_attr`, `call` methods for VecEnvWrapper get correctly forwarded to the vector env it is wrapping.""" - env = gym.make_vec("CartPole-v1", num_envs=3) - wrapped = DummyVectorWrapper(gym.make_vec("CartPole-v1", num_envs=3)) + env = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync") + wrapped = DummyVectorWrapper( + gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync") + ) - assert np.allclose(wrapped.call("gravity"), env.call("gravity")) + assert np.allclose(wrapped.env.call("gravity"), env.call("gravity")) env.set_attr("gravity", [20.0, 20.0, 20.0]) - wrapped.set_attr("gravity", [20.0, 20.0, 20.0]) - assert np.allclose(wrapped.get_attr("gravity"), env.get_attr("gravity")) + wrapped.env.set_attr("gravity", [20.0, 20.0, 20.0]) + assert np.allclose(wrapped.env.get_attr("gravity"), env.get_attr("gravity")) + + env.close() diff --git a/tests/experimental/vector/testing_utils.py b/tests/vector/testing_utils.py similarity index 98% rename from tests/experimental/vector/testing_utils.py rename to tests/vector/testing_utils.py index 7cde743e0..8f35dff83 100644 --- a/tests/experimental/vector/testing_utils.py +++ b/tests/vector/testing_utils.py @@ -1,4 +1,4 @@ -"""Testing utilitys for `gymnasium.experimental.vector`.""" +"""Testing utilities for `gymnasium.vector`.""" import time from typing import Optional diff --git a/tests/vector/utils.py b/tests/vector/utils.py deleted file mode 100644 index 6e8adaa4f..000000000 --- a/tests/vector/utils.py +++ /dev/null @@ -1,141 +0,0 @@ -import time -from typing import Optional - -import numpy as np - -import gymnasium as gym -from gymnasium.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple -from gymnasium.utils.seeding import RandomNumberGenerator - - -spaces = [ - Box(low=np.array(-1.0), high=np.array(1.0), dtype=np.float64), - Box(low=np.array([0.0]), high=np.array([10.0]), dtype=np.float64), - Box( - low=np.array([-1.0, 0.0, 0.0]), high=np.array([1.0, 1.0, 1.0]), dtype=np.float64 - ), - Box( - low=np.array([[-1.0, 0.0], [0.0, -1.0]]), high=np.ones((2, 2)), dtype=np.float64 - ), - Box(low=0, high=255, shape=(), dtype=np.uint8), - Box(low=0, high=255, shape=(32, 32,
3), dtype=np.uint8), - Discrete(2), - Discrete(5, start=-2), - Tuple((Discrete(3), Discrete(5))), - Tuple( - ( - Discrete(7), - Box(low=np.array([0.0, -1.0]), high=np.array([1.0, 1.0]), dtype=np.float64), - ) - ), - MultiDiscrete([11, 13, 17]), - MultiDiscrete([10, 14, 18], start=[-5, -7, -9]), - MultiBinary(19), - Dict( - { - "position": Discrete(23), - "velocity": Box( - low=np.array([0.0]), high=np.array([1.0]), dtype=np.float64 - ), - } - ), - Dict( - { - "position": Dict({"x": Discrete(29), "y": Discrete(31)}), - "velocity": Tuple( - (Discrete(37), Box(low=0, high=255, shape=(), dtype=np.uint8)) - ), - } - ), -] - -HEIGHT, WIDTH = 64, 64 - - -class UnittestSlowEnv(gym.Env): - def __init__(self, slow_reset=0.3): - super().__init__() - self.slow_reset = slow_reset - self.observation_space = Box( - low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8 - ) - self.action_space = Box(low=0.0, high=1.0, shape=(), dtype=np.float32) - - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): - super().reset(seed=seed) - if self.slow_reset > 0: - time.sleep(self.slow_reset) - return self.observation_space.sample(), {} - - def step(self, action): - time.sleep(action) - observation = self.observation_space.sample() - reward, terminated, truncated = 0.0, False, False - return observation, reward, terminated, truncated, {} - - -class CustomSpace(gym.Space): - """Minimal custom observation space.""" - - def sample(self): - return self.np_random.integers(0, 10, ()) - - def contains(self, x): - return 0 <= x <= 10 - - def __eq__(self, other): - return isinstance(other, CustomSpace) - - -custom_spaces = [ - CustomSpace(), - Tuple((CustomSpace(), Box(low=0, high=255, shape=(), dtype=np.uint8))), -] - - -class CustomSpaceEnv(gym.Env): - def __init__(self): - super().__init__() - self.observation_space = CustomSpace() - self.action_space = CustomSpace() - - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): - super().reset(seed=seed) - return "reset", {} - - def step(self, action): - observation = f"step({action:s})" - reward, terminated, truncated = 0.0, False, False - return observation, reward, terminated, truncated, {} - - -def make_env(env_name, seed, **kwargs): - def _make(): - env = gym.make(env_name, disable_env_checker=True, **kwargs) - env.action_space.seed(seed) - env.reset(seed=seed) - return env - - return _make - - -def make_slow_env(slow_reset, seed): - def _make(): - env = UnittestSlowEnv(slow_reset=slow_reset) - env.reset(seed=seed) - return env - - return _make - - -def make_custom_space_env(seed): - def _make(): - env = CustomSpaceEnv() - env.reset(seed=seed) - return env - - return _make - - -def assert_rng_equal(rng_1: RandomNumberGenerator, rng_2: RandomNumberGenerator): - assert rng_1.bit_generator.state == rng_2.bit_generator.state diff --git a/tests/vector/utils/__init__.py b/tests/vector/utils/__init__.py new file mode 100644 index 000000000..56d1d7e57 --- /dev/null +++ b/tests/vector/utils/__init__.py @@ -0,0 +1 @@ +"""Module for testing `gymnasium.vector.utils` functions.""" diff --git a/tests/experimental/vector/utils/test_shared_memory.py b/tests/vector/utils/test_shared_memory.py similarity index 96% rename from tests/experimental/vector/utils/test_shared_memory.py rename to tests/vector/utils/test_shared_memory.py index daea7c2f7..c4b732c81 100644 --- a/tests/experimental/vector/utils/test_shared_memory.py +++ b/tests/vector/utils/test_shared_memory.py @@ -1,4 +1,4 @@ -"""Tests 
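The deleted `make_env`/`make_custom_space_env` helpers above illustrate the factory pattern the vector envs expect: a list of zero-argument callables, with one closure per sub-environment so each gets its own seed. A sketch of the same pattern:

```python
import gymnasium as gym
from gymnasium.vector import SyncVectorEnv


def make_env(env_id: str, seed: int):
    # Returning a thunk (rather than an env) lets each worker build its own
    # instance; binding `seed` here avoids Python's late-binding lambda pitfall.
    def _make() -> gym.Env:
        env = gym.make(env_id)
        env.action_space.seed(seed)
        env.reset(seed=seed)
        return env

    return _make


envs = SyncVectorEnv([make_env("CartPole-v1", seed) for seed in range(4)])
envs.close()
```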
`gymnasium.experimental.vector.utils.shared_memory functions.""" +"""Tests `gymnasium.vector.utils.shared_memory` functions.""" import multiprocessing as mp import re @@ -7,12 +7,12 @@ import pytest from gymnasium import Space from gymnasium.error import CustomSpaceError -from gymnasium.experimental.vector.utils import ( +from gymnasium.utils.env_checker import data_equivalence +from gymnasium.vector.utils import ( create_shared_memory, read_from_shared_memory, write_to_shared_memory, ) -from gymnasium.utils.env_checker import data_equivalence from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS diff --git a/tests/experimental/vector/utils/test_space_utils.py b/tests/vector/utils/test_space_utils.py similarity index 96% rename from tests/experimental/vector/utils/test_space_utils.py rename to tests/vector/utils/test_space_utils.py index ee4bdaff1..18bfc9cab 100644 --- a/tests/experimental/vector/utils/test_space_utils.py +++ b/tests/vector/utils/test_space_utils.py @@ -1,4 +1,4 @@ -"""Testing `gymnasium.experimental.vector.utils.space_utils` functions.""" +"""Testing `gymnasium.vector.utils.space_utils` functions.""" import copy import re @@ -8,16 +8,11 @@ import pytest from gymnasium import Space from gymnasium.error import CustomSpaceError -from gymnasium.experimental.vector.utils import ( - batch_space, - concatenate, - create_empty_array, - iterate, -) from gymnasium.spaces import Tuple from gymnasium.utils.env_checker import data_equivalence -from tests.experimental.vector.utils.utils import is_rng_equal +from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS, CustomSpace +from tests.vector.utils.utils import is_rng_equal @pytest.mark.parametrize("space", TESTING_SPACES, ids=TESTING_SPACES_IDS) diff --git a/tests/experimental/vector/utils/utils.py b/tests/vector/utils/utils.py similarity index 100% rename from tests/experimental/vector/utils/utils.py rename to tests/vector/utils/utils.py diff --git a/tests/wrappers/__init__.py b/tests/wrappers/__init__.py index e69de29bb..d076ba110 100644 --- a/tests/wrappers/__init__.py +++ b/tests/wrappers/__init__.py @@ -0,0 +1 @@ +"""Test suite for the wrappers.""" diff --git a/tests/wrappers/test_atari_preprocessing.py b/tests/wrappers/test_atari_preprocessing.py index 3c8121588..851f26477 100644 --- a/tests/wrappers/test_atari_preprocessing.py +++ b/tests/wrappers/test_atari_preprocessing.py @@ -1,9 +1,11 @@ +"""Test suite for AtariPreprocessing wrapper.""" + import numpy as np import pytest from gymnasium.spaces import Box, Discrete -from gymnasium.wrappers import AtariPreprocessing, StepAPICompatibility -from tests.testing_env import GenericTestEnv, old_step_func +from gymnasium.wrappers import AtariPreprocessing +from tests.testing_env import GenericTestEnv class AleTesting: @@ -34,7 +36,6 @@ class AtariTestingEnv(GenericTestEnv): low=0, high=255, shape=(210, 160, 3), dtype=np.uint8, seed=1 ), action_space=Discrete(3, seed=1), - step_func=old_step_func, ) self.ale = AleTesting() @@ -44,12 +45,12 @@ class AtariTestingEnv(GenericTestEnv): @pytest.mark.parametrize( - "env, obs_shape", + "env, expected_obs_shape", [ (AtariTestingEnv(), (210, 160, 3)), ( AtariPreprocessing( - StepAPICompatibility(AtariTestingEnv(), output_truncation_bool=True), + AtariTestingEnv(), screen_size=84, grayscale_obs=True, frame_skip=1, @@ -59,7 +60,7 @@ ), ( AtariPreprocessing( - StepAPICompatibility(AtariTestingEnv(),
output_truncation_bool=True), + AtariTestingEnv(), screen_size=84, grayscale_obs=False, frame_skip=1, @@ -69,7 +70,7 @@ class AtariTestingEnv(GenericTestEnv): ), ( AtariPreprocessing( - StepAPICompatibility(AtariTestingEnv(), output_truncation_bool=True), + AtariTestingEnv(), screen_size=84, grayscale_obs=True, frame_skip=1, @@ -80,15 +81,8 @@ class AtariTestingEnv(GenericTestEnv): ), ], ) -def test_atari_preprocessing_grayscale(env, obs_shape): - assert env.observation_space.shape == obs_shape - - # It is not possible to test the outputs as we are not using actual observations. - # todo: update when ale-py is compatible with the ci - - env = StepAPICompatibility( - env, output_truncation_bool=True - ) # using compatibility wrapper since ale-py uses old step API +def test_atari_preprocessing_grayscale(env, expected_obs_shape): + assert env.observation_space.shape == expected_obs_shape obs, _ = env.reset(seed=0) assert obs in env.observation_space @@ -104,7 +98,7 @@ def test_atari_preprocessing_grayscale(env, obs_shape): def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10): # arbitrarily chosen number for stepping into env. and ensuring all observations are in the required range env = AtariPreprocessing( - StepAPICompatibility(AtariTestingEnv(), output_truncation_bool=True), + AtariTestingEnv(), screen_size=84, grayscale_obs=grayscale, scale_obs=scaled, diff --git a/tests/wrappers/test_autoreset.py b/tests/wrappers/test_autoreset.py index f72047eed..1431b061f 100644 --- a/tests/wrappers/test_autoreset.py +++ b/tests/wrappers/test_autoreset.py @@ -1,104 +1,31 @@ -"""Tests the gymnasium.wrapper.AutoResetWrapper operates as expected.""" -from typing import Generator, Optional -from unittest.mock import MagicMock - +"""Test suite for Autoreset wrapper.""" import numpy as np -import pytest import gymnasium as gym -from gymnasium.wrappers import AutoResetWrapper -from tests.envs.utils import all_testing_env_specs +from gymnasium.wrappers import Autoreset +from tests.testing_env import GenericTestEnv -class DummyResetEnv(gym.Env): - """A dummy environment which returns ascending numbers starting at `0` when :meth:`self.step()` is called. - - After the second call to :meth:`self.step()` terminated is true. - Info dicts are also returned containing the same number returned as an observation, accessible via the key "count". - This environment is provided for the purpose of testing the autoreset wrapper. 
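Since the tests above drive `AtariPreprocessing` with a fake ALE, typical real usage is worth sketching; this assumes `ale-py` is installed and uses `ALE/Pong-v5` purely as an illustration:

```python
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing

# frameskip=1 on the base env, since the wrapper applies its own frame_skip.
env = gym.make("ALE/Pong-v5", frameskip=1)
env = AtariPreprocessing(
    env, frame_skip=4, screen_size=84, grayscale_obs=True, scale_obs=False
)
assert env.observation_space.shape == (84, 84)  # grayscale, no channel axis
```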
- """ - - metadata = {} - - def __init__(self): - """Initialise the DummyResetEnv.""" - self.action_space = gym.spaces.Box( - low=np.array([0]), high=np.array([2]), dtype=np.int64 - ) - self.observation_space = gym.spaces.Discrete(2) - self.count = 0 - - def step(self, action: int): - """Steps the DummyEnv with the incremented step, reward and terminated `if self.count > 1` and updated info.""" - self.count += 1 - return ( - np.array([self.count]), # Obs - self.count > 2, # Reward - self.count > 2, # Terminated - False, # Truncated - {"count": self.count}, # Info - ) - - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): - """Resets the DummyEnv to return the count array and info with count.""" - self.count = 0 - return np.array([self.count]), {"count": self.count} +def autoreset_reset_func(self: gym.Env, seed=None, options=None): + self.count = 0 + return np.array([self.count]), {"count": self.count} -def unwrap_env(env) -> Generator[gym.Wrapper, None, None]: - """Unwraps an environment yielding all wrappers around environment.""" - while isinstance(env, gym.Wrapper): - yield type(env) - env = env.env - - -@pytest.mark.parametrize( - "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs] -) -def test_make_autoreset_true(spec): - """Tests gym.make with `autoreset=True`, and check that the reset actually happens. - - Note: This test assumes that the outermost wrapper is AutoResetWrapper so if that - is being changed in the future, this test will break and need to be updated. - Note: This test assumes that all first-party environments will terminate in a finite - amount of time with random actions, which is true as of the time of adding this test. - """ - env = gym.make(spec.id, autoreset=True, disable_env_checker=True) - assert AutoResetWrapper in unwrap_env(env) - - env.reset(seed=0) - env.unwrapped.reset = MagicMock(side_effect=env.unwrapped.reset) - - terminated, truncated = False, False - while not (terminated or truncated): - obs, reward, terminated, truncated, info = env.step(env.action_space.sample()) - - assert env.unwrapped.reset.called - env.close() - - -@pytest.mark.parametrize( - "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs] -) -def test_gym_make_autoreset(spec): - """Tests that `gym.make` autoreset wrapper is applied only when `gym.make(..., autoreset=True)`.""" - env = gym.make(spec.id, disable_env_checker=True) - assert AutoResetWrapper not in unwrap_env(env) - env.close() - - env = gym.make(spec.id, autoreset=False, disable_env_checker=True) - assert AutoResetWrapper not in unwrap_env(env) - env.close() - - env = gym.make(spec.id, autoreset=True, disable_env_checker=True) - assert AutoResetWrapper in unwrap_env(env) - env.close() +def autoreset_step_func(self: gym.Env, action: int): + self.count += 1 + return ( + np.array([self.count]), # Obs + self.count > 2, # Reward + self.count > 2, # Terminated + False, # Truncated + {"count": self.count}, # Info + ) def test_autoreset_wrapper_autoreset(): """Tests the autoreset wrapper actually automatically resets correctly.""" - env = DummyResetEnv() - env = AutoResetWrapper(env) + env = GenericTestEnv(reset_func=autoreset_reset_func, step_func=autoreset_step_func) + env = Autoreset(env) obs, info = env.reset() assert obs == np.array([0]) diff --git a/tests/wrappers/test_clip_action.py b/tests/wrappers/test_clip_action.py index 1cb9b804e..8b6e94d94 100644 --- a/tests/wrappers/test_clip_action.py +++ b/tests/wrappers/test_clip_action.py @@ -1,28 
+1,25 @@ +"""Test suite for ClipAction wrapper.""" + import numpy as np -import gymnasium as gym +from gymnasium.spaces import Box from gymnasium.wrappers import ClipAction +from tests.testing_env import GenericTestEnv +from tests.wrappers.utils import record_action_step -def test_clip_action(): - # mountaincar: action-based rewards - env = gym.make("MountainCarContinuous-v0", disable_env_checker=True) - wrapped_env = ClipAction( - gym.make("MountainCarContinuous-v0", disable_env_checker=True) +def test_clip_action_wrapper(): + """Test that the action is correctly clipped to the base environment action space.""" + env = GenericTestEnv( + action_space=Box(np.array([0, 0, 3]), np.array([1, 2, 4])), + step_func=record_action_step, ) + wrapped_env = ClipAction(env) - seed = 0 + sampled_action = np.array([-1, 5, 3.5], dtype=np.float32) + assert sampled_action not in env.action_space + assert sampled_action in wrapped_env.action_space - env.reset(seed=seed) - wrapped_env.reset(seed=seed) - - actions = [[0.4], [1.2], [-0.3], [0.0], [-2.5]] - for action in actions: - obs1, r1, ter1, trunc1, _ = env.step( - np.clip(action, env.action_space.low, env.action_space.high) - ) - obs2, r2, ter2, trunc2, _ = wrapped_env.step(action) - assert np.allclose(r1, r2) - assert np.allclose(obs1, obs2) - assert ter1 == ter2 - assert trunc1 == trunc2 + _, _, _, _, info = wrapped_env.step(sampled_action) + assert np.all(info["action"] in env.action_space) + assert np.all(info["action"] == np.array([0, 2, 3.5])) diff --git a/tests/wrappers/test_clip_reward.py b/tests/wrappers/test_clip_reward.py new file mode 100644 index 000000000..42b3606f2 --- /dev/null +++ b/tests/wrappers/test_clip_reward.py @@ -0,0 +1,42 @@ +"""Test suite for ClipReward wrapper.""" +import numpy as np +import pytest + +import gymnasium as gym +from gymnasium.error import InvalidBound +from gymnasium.wrappers import ClipReward +from tests.wrappers.utils import DISCRETE_ACTION, ENV_ID, SEED + + +@pytest.mark.parametrize( + ("lower_bound", "upper_bound", "expected_reward"), + [(None, 0.5, 0.5), (0, None, 1), (0, 0.5, 0.5)], +) +def test_clip_reward_wrapper(lower_bound, upper_bound, expected_reward): + """Test reward clipping. + + Test if reward is correctly clipped according to the input args. + """ + env = gym.make(ENV_ID) + env = ClipReward(env, lower_bound, upper_bound) + env.reset(seed=SEED) + _, rew, _, _, _ = env.step(DISCRETE_ACTION) + + assert rew == expected_reward + + +@pytest.mark.parametrize( + ("lower_bound", "upper_bound"), + [(None, None), (1, -1), (np.array([1, 1]), np.array([0, 0]))], +) +def test_clip_reward_incorrect_params(lower_bound, upper_bound): + """Test reward clipping with incorrect params. + + Test whether passing wrong params to clip_rewards correctly raises an exception. + clip_rewards should raise an exception if both the lower and upper bounds of the reward are `None`, + or if the upper bound is lower than the lower bound.
+ """ + env = gym.make(ENV_ID) + + with pytest.raises(InvalidBound): + ClipReward(env, lower_bound, upper_bound) diff --git a/tests/experimental/wrappers/test_delay_observation.py b/tests/wrappers/test_delay_observation.py similarity index 81% rename from tests/experimental/wrappers/test_delay_observation.py rename to tests/wrappers/test_delay_observation.py index 470d18013..d47daf6a8 100644 --- a/tests/experimental/wrappers/test_delay_observation.py +++ b/tests/wrappers/test_delay_observation.py @@ -1,17 +1,13 @@ -"""Test suite for DelayObservationV0.""" +"""Test suite for DelayObservation wrapper.""" import re import pytest import gymnasium as gym -from gymnasium.experimental.wrappers import DelayObservationV0 -from gymnasium.experimental.wrappers.utils import create_zero_array from gymnasium.utils.env_checker import data_equivalence -from tests.experimental.wrappers.utils import ( - SEED, - TESTING_OBS_ENVS, - TESTING_OBS_ENVS_IDS, -) +from gymnasium.wrappers import DelayObservation +from gymnasium.wrappers.utils import create_zero_array +from tests.wrappers.utils import SEED, TESTING_OBS_ENVS, TESTING_OBS_ENVS_IDS @pytest.mark.parametrize("env", TESTING_OBS_ENVS, ids=TESTING_OBS_ENVS_IDS) @@ -25,7 +21,7 @@ def test_env_obs(env, delay: int = 3, extra_steps: int = 4): obs, _, _, _, _ = env.step(env.action_space.sample()) undelayed_obs.append(obs) - env = DelayObservationV0(env, delay=delay) + env = DelayObservation(env, delay=delay) example_zero_obs = create_zero_array(env.observation_space) env.action_space.seed(SEED) obs, _ = env.reset(seed=SEED) @@ -50,7 +46,7 @@ def test_delay_values(delay): env = gym.make("CartPole-v1") first_obs, _ = env.reset(seed=123) - env = DelayObservationV0(gym.make("CartPole-v1"), delay=delay) + env = DelayObservation(gym.make("CartPole-v1"), delay=delay) zero_obs = create_zero_array(env.observation_space) obs, _ = env.reset(seed=123) assert data_equivalence(obs, zero_obs) @@ -72,10 +68,10 @@ def test_delay_failures(): "The delay is expected to be an integer, actual type: " ), ): - DelayObservationV0(env, delay=1.0) + DelayObservation(env, delay=1.0) with pytest.raises( ValueError, match=re.escape("The delay needs to be greater than zero, actual value: -1"), ): - DelayObservationV0(env, delay=-1) + DelayObservation(env, delay=-1) diff --git a/tests/experimental/wrappers/test_dtype_observation.py b/tests/wrappers/test_dtype_observation.py similarity index 65% rename from tests/experimental/wrappers/test_dtype_observation.py rename to tests/wrappers/test_dtype_observation.py index 1d6233638..c54d52228 100644 --- a/tests/experimental/wrappers/test_dtype_observation.py +++ b/tests/wrappers/test_dtype_observation.py @@ -1,12 +1,9 @@ -"""Test suite for DtypeObservationV0.""" +"""Test suite for DtypeObservation wrapper.""" import numpy as np -from gymnasium.experimental.wrappers import DtypeObservationV0 -from tests.experimental.wrappers.utils import ( - record_random_obs_reset, - record_random_obs_step, -) +from gymnasium.wrappers import DtypeObservation from tests.testing_env import GenericTestEnv +from tests.wrappers.utils import record_random_obs_reset, record_random_obs_step def test_dtype_observation(): @@ -14,7 +11,7 @@ def test_dtype_observation(): env = GenericTestEnv( reset_func=record_random_obs_reset, step_func=record_random_obs_step ) - wrapped_env = DtypeObservationV0(env, dtype=np.uint8) + wrapped_env = DtypeObservation(env, dtype=np.uint8) obs, info = wrapped_env.reset() assert obs.dtype != info["obs"].dtype diff --git 
a/tests/wrappers/test_filter_observation.py b/tests/wrappers/test_filter_observation.py index 217536ace..51bbd48d8 100644 --- a/tests/wrappers/test_filter_observation.py +++ b/tests/wrappers/test_filter_observation.py @@ -1,87 +1,81 @@ -from typing import Optional, Tuple - -import numpy as np +"""Test suite for FilterObservation wrapper.""" import pytest -import gymnasium as gym -from gymnasium import spaces -from gymnasium.wrappers.filter_observation import FilterObservation - - -class FakeEnvironment(gym.Env): - def __init__( - self, render_mode=None, observation_keys: Tuple[str, ...] = ("state",) - ): - self.observation_space = spaces.Dict( - { - name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32) - for name in observation_keys - } - ) - self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32) - self.render_mode = render_mode - - def render(self, mode="human"): - image_shape = (32, 32, 3) - return np.zeros(image_shape, dtype=np.uint8) - - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): - super().reset(seed=seed) - observation = self.observation_space.sample() - return observation, {} - - def step(self, action): - del action - observation = self.observation_space.sample() - reward, terminal, info = 0.0, False, {} - return observation, reward, terminal, info - - -FILTER_OBSERVATION_TEST_CASES = ( - (("key1", "key2"), ("key1",)), - (("key1", "key2"), ("key1", "key2")), - (("key1",), None), - (("key1",), ("key1",)), -) - -ERROR_TEST_CASES = ( - ("key", ValueError, "All the filter_keys must be included..*"), - (False, TypeError, "'bool' object is not iterable"), - (1, TypeError, "'int' object is not iterable"), +from gymnasium.spaces import Box, Dict, Tuple +from gymnasium.wrappers import FilterObservation +from tests.testing_env import GenericTestEnv +from tests.wrappers.utils import ( + check_obs, + record_random_obs_reset, + record_random_obs_step, ) -class TestFilterObservation: - @pytest.mark.parametrize( - "observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES +def test_filter_observation_wrapper(): + """Tests ``FilterObservation`` that the right keys are filtered.""" + dict_env = GenericTestEnv( + observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3), arm_3=Box(-1, 1)), + reset_func=record_random_obs_reset, + step_func=record_random_obs_step, ) - def test_filter_observation(self, observation_keys, filter_keys): - env = FakeEnvironment(observation_keys=observation_keys) - # Make sure we are testing the right environment for the test. 
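The rewritten tests above exercise `FilterObservation` on `Dict` (and `Tuple`) observations; a self-contained sketch with a hypothetical toy environment (`ArmEnv` is invented for illustration, standing in for the internal `GenericTestEnv`):

```python
import gymnasium as gym
from gymnasium.spaces import Box, Dict
from gymnasium.wrappers import FilterObservation


class ArmEnv(gym.Env):
    """Hypothetical env with a Dict observation space."""

    observation_space = Dict(arm_1=Box(0, 1), arm_2=Box(2, 3), arm_3=Box(-1, 1))
    action_space = Box(0, 1)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        return self.observation_space.sample(), 0.0, False, False, {}


env = FilterObservation(ArmEnv(), filter_keys=("arm_1", "arm_3"))
obs, _ = env.reset(seed=0)
assert sorted(obs.keys()) == ["arm_1", "arm_3"]
```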
- observation_space = env.observation_space - assert isinstance(observation_space, spaces.Dict) + wrapped_env = FilterObservation(dict_env, ("arm_1", "arm_3")) + obs, info = wrapped_env.reset() + assert list(obs.keys()) == ["arm_1", "arm_3"] + assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"] + check_obs(dict_env, wrapped_env, obs, info["obs"]) - wrapped_env = FilterObservation(env, filter_keys=filter_keys) + obs, _, _, _, info = wrapped_env.step(None) + assert list(obs.keys()) == ["arm_1", "arm_3"] + assert list(info["obs"].keys()) == ["arm_1", "arm_2", "arm_3"] + check_obs(dict_env, wrapped_env, obs, info["obs"]) - assert isinstance(wrapped_env.observation_space, spaces.Dict) + # Test tuple environments + tuple_env = GenericTestEnv( + observation_space=Tuple((Box(0, 1), Box(2, 3), Box(-1, 1))), + reset_func=record_random_obs_reset, + step_func=record_random_obs_step, + ) + wrapped_env = FilterObservation(tuple_env, (2,)) - if filter_keys is None: - filter_keys = tuple(observation_keys) + obs, info = wrapped_env.reset() + assert len(obs) == 1 and len(info["obs"]) == 3 + check_obs(tuple_env, wrapped_env, obs, info["obs"]) - assert len(wrapped_env.observation_space.spaces) == len(filter_keys) - assert tuple(wrapped_env.observation_space.spaces.keys()) == tuple(filter_keys) + obs, _, _, _, info = wrapped_env.step(None) + assert len(obs) == 1 and len(info["obs"]) == 3 + check_obs(tuple_env, wrapped_env, obs, info["obs"]) - # Check that the added space item is consistent with the added observation. - observation, info = wrapped_env.reset() - assert len(observation) == len(filter_keys) - assert isinstance(info, dict) - @pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES) - def test_raises_with_incorrect_arguments( - self, filter_keys, error_type, error_match - ): - env = FakeEnvironment(observation_keys=("key1", "key2")) +@pytest.mark.parametrize( + "filter_keys, error_type, error_match", + ( + ( + "key", + ValueError, + "All the `filter_keys` must be included in the observation space.", + ), + ( + False, + TypeError, + "Expects `filter_keys` to be a Sequence, actual type: ", + ), + ( + 1, + TypeError, + "Expects `filter_keys` to be a Sequence, actual type: ", + ), + ( + (), + ValueError, + "The observation space is empty due to filtering all of the keys", + ), + ), +) +def test_incorrect_arguments(filter_keys, error_type, error_match): + env = GenericTestEnv( + observation_space=Dict(key_1=Box(0, 1), key_2=Box(2, 3)), + ) - with pytest.raises(error_type, match=error_match): - FilterObservation(env, filter_keys=filter_keys) + with pytest.raises(error_type, match=error_match): + FilterObservation(env, filter_keys=filter_keys) diff --git a/tests/wrappers/test_flatten.py b/tests/wrappers/test_flatten.py deleted file mode 100644 index 9c6f08022..000000000 --- a/tests/wrappers/test_flatten.py +++ /dev/null @@ -1,98 +0,0 @@ -"""Tests for the flatten observation wrapper.""" - -from collections import OrderedDict -from typing import Optional - -import numpy as np -import pytest - -import gymnasium as gym -from gymnasium.spaces import Box, Dict, flatten, unflatten -from gymnasium.wrappers import FlattenObservation - - -class FakeEnvironment(gym.Env): - def __init__(self, observation_space): - self.observation_space = observation_space - - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): - super().reset(seed=seed) - self.observation = self.observation_space.sample() - return self.observation, {} - - -OBSERVATION_SPACES = ( - ( - 
Dict( - OrderedDict( - [ - ("key1", Box(shape=(2, 3), low=0, high=0, dtype=np.float32)), - ("key2", Box(shape=(), low=1, high=1, dtype=np.float32)), - ("key3", Box(shape=(2,), low=2, high=2, dtype=np.float32)), - ] - ) - ), - True, - ), - ( - Dict( - OrderedDict( - [ - ("key2", Box(shape=(), low=0, high=0, dtype=np.float32)), - ("key3", Box(shape=(2,), low=1, high=1, dtype=np.float32)), - ("key1", Box(shape=(2, 3), low=2, high=2, dtype=np.float32)), - ] - ) - ), - True, - ), - ( - Dict( - { - "key1": Box(shape=(2, 3), low=-1, high=1, dtype=np.float32), - "key2": Box(shape=(), low=-1, high=1, dtype=np.float32), - "key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32), - } - ), - False, - ), -) - - -class TestFlattenEnvironment: - @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES) - def test_flattened_environment(self, observation_space, ordered_values): - """ - make sure that flattened observations occur in the order expected - """ - env = FakeEnvironment(observation_space=observation_space) - wrapped_env = FlattenObservation(env) - flattened, info = wrapped_env.reset() - - unflattened = unflatten(env.observation_space, flattened) - original = env.observation - - self._check_observations(original, flattened, unflattened, ordered_values) - - @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES) - def test_flatten_unflatten(self, observation_space, ordered_values): - """ - test flatten and unflatten functions directly - """ - original = observation_space.sample() - - flattened = flatten(observation_space, original) - unflattened = unflatten(observation_space, flattened) - - self._check_observations(original, flattened, unflattened, ordered_values) - - def _check_observations(self, original, flattened, unflattened, ordered_values): - # make sure that unflatten(flatten(original)) == original - assert set(unflattened.keys()) == set(original.keys()) - for k, v in original.items(): - np.testing.assert_allclose(unflattened[k], v) - - if ordered_values: - # make sure that the values were flattened in the order they appeared in the - # OrderedDict - np.testing.assert_allclose(sorted(flattened), flattened) diff --git a/tests/wrappers/test_flatten_observation.py b/tests/wrappers/test_flatten_observation.py index 31da398a7..1cf664580 100644 --- a/tests/wrappers/test_flatten_observation.py +++ b/tests/wrappers/test_flatten_observation.py @@ -1,23 +1,29 @@ -import numpy as np -import pytest +"""Test suite for FlattenObservation wrapper.""" -import gymnasium as gym -from gymnasium import spaces +from gymnasium.spaces import Box, Dict, flatten_space from gymnasium.wrappers import FlattenObservation +from tests.testing_env import GenericTestEnv +from tests.wrappers.utils import ( + check_obs, + record_random_obs_reset, + record_random_obs_step, +) -@pytest.mark.parametrize("env_id", ["Blackjack-v1"]) -def test_flatten_observation(env_id): - env = gym.make(env_id, disable_env_checker=True) +def test_flatten_observation_wrapper(): + """Tests the ``FlattenObservation`` wrapper that the observations are flattened correctly.""" + env = GenericTestEnv( + observation_space=Dict(arm=Box(0, 1), head=Box(2, 3)), + reset_func=record_random_obs_reset, + step_func=record_random_obs_step, + ) wrapped_env = FlattenObservation(env) - obs, info = env.reset() - wrapped_obs, wrapped_obs_info = wrapped_env.reset() + assert wrapped_env.observation_space == flatten_space(env.observation_space) + assert wrapped_env.action_space == env.action_space - space =
spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))) - wrapped_space = spaces.Box(0, 1, [32 + 11 + 2], dtype=np.int64) + obs, info = wrapped_env.reset() + check_obs(env, wrapped_env, obs, info["obs"]) - assert space.contains(obs) - assert wrapped_space.contains(wrapped_obs) - assert isinstance(info, dict) - assert isinstance(wrapped_obs_info, dict) + obs, _, _, _, info = wrapped_env.step(None) + check_obs(env, wrapped_env, obs, info["obs"]) diff --git a/tests/wrappers/test_frame_stack.py b/tests/wrappers/test_frame_stack.py deleted file mode 100644 index e8aff980d..000000000 --- a/tests/wrappers/test_frame_stack.py +++ /dev/null @@ -1,53 +0,0 @@ -import numpy as np -import pytest - -import gymnasium as gym -from gymnasium.wrappers import FrameStack - - -try: - import lz4 -except ImportError: - lz4 = None - - -@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1", "CarRacing-v2"]) -@pytest.mark.parametrize("num_stack", [2, 3, 4]) -@pytest.mark.parametrize( - "lz4_compress", - [ - pytest.param( - True, - marks=pytest.mark.skipif( - lz4 is None, reason="Need lz4 to run tests with compression" - ), - ), - False, - ], -) -def test_frame_stack(env_id, num_stack, lz4_compress): - env = gym.make(env_id, disable_env_checker=True) - shape = env.observation_space.shape - env = FrameStack(env, num_stack, lz4_compress) - assert env.observation_space.shape == (num_stack,) + shape - assert env.observation_space.dtype == env.env.observation_space.dtype - - dup = gym.make(env_id, disable_env_checker=True) - - obs, _ = env.reset(seed=0) - dup_obs, _ = dup.reset(seed=0) - assert np.allclose(obs[-1], dup_obs) - - for _ in range(num_stack**2): - action = env.action_space.sample() - dup_obs, _, dup_terminated, dup_truncated, _ = dup.step(action) - obs, _, terminated, truncated, _ = env.step(action) - - assert dup_terminated == terminated - assert dup_truncated == truncated - assert np.allclose(obs[-1], dup_obs) - - if terminated or truncated: - break - - assert len(obs) == num_stack diff --git a/tests/experimental/wrappers/test_frame_stack_observation.py b/tests/wrappers/test_frame_stack_observation.py similarity index 81% rename from tests/experimental/wrappers/test_frame_stack_observation.py rename to tests/wrappers/test_frame_stack_observation.py index 080a05f6e..33cfdbc2c 100644 --- a/tests/experimental/wrappers/test_frame_stack_observation.py +++ b/tests/wrappers/test_frame_stack_observation.py @@ -1,18 +1,14 @@ -"""Test suite for FrameStackObservationV0.""" +"""Test suite for FrameStackObservation wrapper.""" import re import pytest import gymnasium as gym -from gymnasium.experimental.vector.utils import iterate -from gymnasium.experimental.wrappers import FrameStackObservationV0 -from gymnasium.experimental.wrappers.utils import create_zero_array from gymnasium.utils.env_checker import data_equivalence -from tests.experimental.wrappers.utils import ( - SEED, - TESTING_OBS_ENVS, - TESTING_OBS_ENVS_IDS, -) +from gymnasium.vector.utils import iterate +from gymnasium.wrappers import FrameStackObservation +from gymnasium.wrappers.utils import create_zero_array +from tests.wrappers.utils import SEED, TESTING_OBS_ENVS, TESTING_OBS_ENVS_IDS @pytest.mark.parametrize("env", TESTING_OBS_ENVS, ids=TESTING_OBS_ENVS_IDS) @@ -29,7 +25,7 @@ def test_env_obs(env, stack_size: int = 3): obs, _, _, _, _ = env.step(env.action_space.sample()) unstacked_obs.append(obs) - env = FrameStackObservationV0(env, stack_size=stack_size) + env = FrameStackObservation(env, stack_size=stack_size) 
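# A minimal stand-alone sketch of the renamed wrapper's behaviour; the env id and
# stack size here are illustrative, not taken from this patch:
import gymnasium as gym
from gymnasium.wrappers import FrameStackObservation

stacked_env = FrameStackObservation(gym.make("CartPole-v1"), stack_size=3)
obs, _ = stacked_env.reset(seed=0)
assert obs.shape == (3, 4)  # (stack_size,) + CartPole's (4,) observation shape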
 env.action_space.seed(SEED)
 
     obs, _ = env.reset(seed=SEED)
 
@@ -59,7 +55,7 @@ def test_stack_size(stack_size: int):
 
     zero_obs = create_zero_array(env.observation_space)
 
-    env = FrameStackObservationV0(env, stack_size=stack_size)
+    env = FrameStackObservation(env, stack_size=stack_size)
     env.action_space.seed(seed=SEED)
 
     obs, _ = env.reset(seed=SEED)
@@ -85,10 +81,10 @@ def test_stack_size_failures():
             "The stack_size is expected to be an integer, actual type: "
         ),
     ):
-        FrameStackObservationV0(env, stack_size=1.0)
+        FrameStackObservation(env, stack_size=1.0)
 
     with pytest.raises(
         ValueError,
         match=re.escape("The stack_size needs to be greater than one, actual value: 0"),
    ):
-        FrameStackObservationV0(env, stack_size=0)
+        FrameStackObservation(env, stack_size=0)
diff --git a/tests/wrappers/test_gray_scale_observation.py b/tests/wrappers/test_gray_scale_observation.py
index cea3fe96d..768be31ca 100644
--- a/tests/wrappers/test_gray_scale_observation.py
+++ b/tests/wrappers/test_gray_scale_observation.py
@@ -1,26 +1,38 @@
-import pytest
+"""Test suite for GrayscaleObservation wrapper."""
+import numpy as np
 
-import gymnasium as gym
-from gymnasium import spaces
-from gymnasium.wrappers import GrayScaleObservation
+from gymnasium.spaces import Box
+from gymnasium.wrappers import GrayscaleObservation
+from tests.testing_env import GenericTestEnv
+from tests.wrappers.utils import (
+    check_obs,
+    record_random_obs_reset,
+    record_random_obs_step,
+)
 
 
-@pytest.mark.parametrize("env_id", ["CarRacing-v2"])
-@pytest.mark.parametrize("keep_dim", [True, False])
-def test_gray_scale_observation(env_id, keep_dim):
-    rgb_env = gym.make(env_id, disable_env_checker=True)
+def test_grayscale_observation_wrapper():
+    """Tests that the ``GrayscaleObservation`` wrapper converts the observation to grayscale."""
+    env = GenericTestEnv(
+        observation_space=Box(0, 255, shape=(25, 25, 3), dtype=np.uint8),
+        reset_func=record_random_obs_reset,
+        step_func=record_random_obs_step,
+    )
+    wrapped_env = GrayscaleObservation(env)
 
-    assert isinstance(rgb_env.observation_space, spaces.Box)
-    assert len(rgb_env.observation_space.shape) == 3
-    assert rgb_env.observation_space.shape[-1] == 3
+    obs, info = wrapped_env.reset()
+    check_obs(env, wrapped_env, obs, info["obs"])
+    assert obs.shape == (25, 25)
 
-    wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim)
-    assert isinstance(wrapped_env.observation_space, spaces.Box)
-    if keep_dim:
-        assert len(wrapped_env.observation_space.shape) == 3
-        assert wrapped_env.observation_space.shape[-1] == 1
-    else:
-        assert len(wrapped_env.observation_space.shape) == 2
+    obs, _, _, _, info = wrapped_env.step(None)
+    check_obs(env, wrapped_env, obs, info["obs"])
 
-    wrapped_obs, info = wrapped_env.reset()
-    assert wrapped_obs in wrapped_env.observation_space
+    # keep_dim=True
+    wrapped_env = GrayscaleObservation(env, keep_dim=True)
+
+    obs, info = wrapped_env.reset()
+    check_obs(env, wrapped_env, obs, info["obs"])
+    assert obs.shape == (25, 25, 1)
+
+    obs, _, _, _, info = wrapped_env.step(None)
+    check_obs(env, wrapped_env, obs, info["obs"])
diff --git a/tests/wrappers/test_human_rendering.py b/tests/wrappers/test_human_rendering.py
index 0583eae0b..d9958840e 100644
--- a/tests/wrappers/test_human_rendering.py
+++ b/tests/wrappers/test_human_rendering.py
@@ -1,3 +1,4 @@
+"""Test suite for HumanRendering wrapper."""
 import re
 
 import pytest
diff --git a/tests/wrappers/test_import_wrappers.py b/tests/wrappers/test_import_wrappers.py
new file mode 100644
index 000000000..81ba67d5f
--- /dev/null
+++ 
b/tests/wrappers/test_import_wrappers.py @@ -0,0 +1,51 @@ +"""Test suite for import wrappers.""" +import re + +import pytest + +import gymnasium +import gymnasium.wrappers as wrappers +from gymnasium.wrappers import __all__ + + +def test_import_wrappers(): + """Test that all wrappers can be imported.""" + # Test that an invalid wrapper raises an AttributeError + with pytest.raises( + AttributeError, + match=re.escape( + "module 'gymnasium.wrappers' has no attribute 'NonexistentWrapper'" + ), + ): + getattr(wrappers, "NonexistentWrapper") + + +@pytest.mark.parametrize("wrapper_name", __all__) +def test_all_wrappers_shortened(wrapper_name): + """Check that each element of the `__all__` wrappers can be loaded, provided dependencies are installed.""" + try: + assert getattr(gymnasium.wrappers, wrapper_name) is not None + except gymnasium.error.DependencyNotInstalled as e: + pytest.skip(str(e)) + + +def test_wrapper_vector(): + assert gymnasium.wrappers.vector is not None + + +@pytest.mark.parametrize( + "wrapper_name", + ("AutoResetWrapper", "FrameStack", "PixelObservationWrapper", "VectorListInfo"), +) +def test_renamed_wrappers(wrapper_name): + with pytest.raises( + AttributeError, match=f"{wrapper_name!r} has been renamed with" + ) as err_message: + getattr(wrappers, wrapper_name) + + new_wrapper_name = err_message.value.args[0][len(wrapper_name) + 35 : -1] + if "vector." in new_wrapper_name: + no_vector_wrapper_name = new_wrapper_name[len("vector.") :] + assert getattr(gymnasium.wrappers.vector, no_vector_wrapper_name) + else: + assert getattr(gymnasium.wrappers, new_wrapper_name) diff --git a/tests/experimental/wrappers/test_jax_to_numpy.py b/tests/wrappers/test_jax_to_numpy.py similarity index 93% rename from tests/experimental/wrappers/test_jax_to_numpy.py rename to tests/wrappers/test_jax_to_numpy.py index 4c5e671f1..11311c66e 100644 --- a/tests/experimental/wrappers/test_jax_to_numpy.py +++ b/tests/wrappers/test_jax_to_numpy.py @@ -1,4 +1,4 @@ -"""Test suite for JaxToNumpyV0.""" +"""Test suite for JaxToNumpy wrapper.""" import numpy as np import pytest @@ -7,12 +7,12 @@ import pytest jax = pytest.importorskip("jax") jnp = pytest.importorskip("jax.numpy") -from gymnasium.experimental.wrappers.jax_to_numpy import ( # noqa: E402 - JaxToNumpyV0, +from gymnasium.utils.env_checker import data_equivalence # noqa: E402 +from gymnasium.wrappers.jax_to_numpy import ( # noqa: E402 + JaxToNumpy, jax_to_numpy, numpy_to_jax, ) -from gymnasium.utils.env_checker import data_equivalence # noqa: E402 from tests.testing_env import GenericTestEnv # noqa: E402 @@ -99,12 +99,14 @@ def test_jax_to_numpy_wrapper(): assert isinstance(info, dict) and isinstance(info["data"], jax.Array) # Check that the wrapped version is correct. 
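# An aside on the conversion helpers exercised above: the round trip is expected
# to be lossless for numpy payloads. A minimal sketch reusing this module's
# imports (the float32 payload is illustrative):
value = {"data": np.array([1.0, 2.0], dtype=np.float32)}
assert data_equivalence(jax_to_numpy(numpy_to_jax(value)), value)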
-    numpy_env = JaxToNumpyV0(jax_env)
+    numpy_env = JaxToNumpy(jax_env)
     obs, info = numpy_env.reset()
     assert isinstance(obs, np.ndarray)
     assert isinstance(info, dict) and isinstance(info["data"], np.ndarray)
 
-    obs, reward, terminated, truncated, info = numpy_env.step(np.array([1, 2]))
+    obs, reward, terminated, truncated, info = numpy_env.step(
+        np.array([1, 2], dtype=np.int32)
+    )
     assert isinstance(obs, np.ndarray)
     assert isinstance(reward, float)
     assert isinstance(terminated, bool) and isinstance(truncated, bool)
diff --git a/tests/experimental/wrappers/test_jax_to_torch.py b/tests/wrappers/test_jax_to_torch.py
similarity index 95%
rename from tests/experimental/wrappers/test_jax_to_torch.py
rename to tests/wrappers/test_jax_to_torch.py
index a74989e09..cb93898de 100644
--- a/tests/experimental/wrappers/test_jax_to_torch.py
+++ b/tests/wrappers/test_jax_to_torch.py
@@ -1,4 +1,4 @@
-"""Test suite for TorchToJaxV0."""
+"""Test suite for JaxToTorch wrapper."""
 import numpy as np
 import pytest
 
@@ -8,8 +8,8 @@ jax = pytest.importorskip("jax")
 jnp = pytest.importorskip("jax.numpy")
 torch = pytest.importorskip("torch")
 
-from gymnasium.experimental.wrappers.jax_to_torch import (  # noqa: E402
-    JaxToTorchV0,
+from gymnasium.wrappers.jax_to_torch import (  # noqa: E402
+    JaxToTorch,
     jax_to_torch,
     torch_to_jax,
 )
@@ -100,7 +100,7 @@ def test_jax_to_torch_wrapper():
     assert isinstance(info, dict) and isinstance(info["data"], jax.Array)
 
     # Check that the wrapped version is correct.
-    wrapped_env = JaxToTorchV0(env)
+    wrapped_env = JaxToTorch(env)
     obs, info = wrapped_env.reset()
     assert isinstance(obs, torch.Tensor)
     assert isinstance(info, dict) and isinstance(info["data"], torch.Tensor)
diff --git a/tests/experimental/wrappers/test_lambda_action.py b/tests/wrappers/test_lambda_action.py
similarity index 68%
rename from tests/experimental/wrappers/test_lambda_action.py
rename to tests/wrappers/test_lambda_action.py
index 81429ee3e..88975b755 100644
--- a/tests/experimental/wrappers/test_lambda_action.py
+++ b/tests/wrappers/test_lambda_action.py
@@ -1,15 +1,15 @@
-"""Test suite for LambdaActionV0."""
+"""Test suite for LambdaAction wrapper."""
 
-from gymnasium.experimental.wrappers import LambdaActionV0
 from gymnasium.spaces import Box
-from tests.experimental.wrappers.utils import record_action_step
+from gymnasium.wrappers import TransformAction
 from tests.testing_env import GenericTestEnv
+from tests.wrappers.utils import record_action_step
 
 
 def test_lambda_action_wrapper():
     """Tests LambdaAction through checking that the action taken is transformed by function."""
     env = GenericTestEnv(step_func=record_action_step)
-    wrapped_env = LambdaActionV0(env, lambda action: action - 2, Box(2, 3))
+    wrapped_env = TransformAction(env, lambda action: action - 2, Box(2, 3))
 
     sampled_action = wrapped_env.action_space.sample()
     assert sampled_action not in env.action_space
diff --git a/tests/experimental/wrappers/test_lambda_observation.py b/tests/wrappers/test_lambda_observation.py
similarity index 66%
rename from tests/experimental/wrappers/test_lambda_observation.py
rename to tests/wrappers/test_lambda_observation.py
index af26e7b89..3e4c09074 100644
--- a/tests/experimental/wrappers/test_lambda_observation.py
+++ b/tests/wrappers/test_lambda_observation.py
@@ -1,15 +1,11 @@
-"""Test suite for lambda observation wrappers."""
+"""Test suite for LambdaObservation wrappers."""
 import numpy as np
 
-from gymnasium.experimental.wrappers import LambdaObservationV0
 from gymnasium.spaces import Box
-from 
tests.experimental.wrappers.utils import ( - check_obs, - record_action_as_obs_step, - record_obs_reset, -) +from gymnasium.wrappers import TransformObservation from tests.testing_env import GenericTestEnv +from tests.wrappers.utils import check_obs, record_action_as_obs_step, record_obs_reset def test_lambda_observation_wrapper(): @@ -17,7 +13,7 @@ def test_lambda_observation_wrapper(): env = GenericTestEnv( reset_func=record_obs_reset, step_func=record_action_as_obs_step ) - wrapped_env = LambdaObservationV0(env, lambda _obs: _obs + 2, Box(2, 3)) + wrapped_env = TransformObservation(env, lambda _obs: _obs + 2, Box(2, 3)) obs, info = wrapped_env.reset(options={"obs": np.array([0], dtype=np.float32)}) check_obs(env, wrapped_env, obs, info["obs"]) diff --git a/tests/wrappers/test_lambda_reward.py b/tests/wrappers/test_lambda_reward.py new file mode 100644 index 000000000..8beb746f3 --- /dev/null +++ b/tests/wrappers/test_lambda_reward.py @@ -0,0 +1,15 @@ +"""Test suite for LambdaReward wrapper.""" + +from gymnasium.wrappers import TransformReward +from tests.testing_env import GenericTestEnv +from tests.wrappers.utils import record_action_as_record_step + + +def test_lambda_reward(): + env = GenericTestEnv(step_func=record_action_as_record_step) + wrapped_env = TransformReward(env, lambda r: 2 * r + 1) + + _, rew, _, _, _ = wrapped_env.step(0) + assert rew == 1 + _, rew, _, _, _ = wrapped_env.step(1) + assert rew == 3 diff --git a/tests/experimental/wrappers/test_max_and_skip_observation.py b/tests/wrappers/test_max_and_skip_observation.py similarity index 64% rename from tests/experimental/wrappers/test_max_and_skip_observation.py rename to tests/wrappers/test_max_and_skip_observation.py index bcde7c1d0..0be057de8 100644 --- a/tests/experimental/wrappers/test_max_and_skip_observation.py +++ b/tests/wrappers/test_max_and_skip_observation.py @@ -1,25 +1,29 @@ -"""Test suite for MaxAndSkipObservationV0.""" +"""Test suite for MaxAndSkipObservation wrapper.""" import re import pytest import gymnasium as gym -from gymnasium.experimental.wrappers import MaxAndSkipObservationV0 +from gymnasium.wrappers import MaxAndSkipObservation def test_max_and_skip_obs(skip: int = 4): """Test MaxAndSkipObservationV0.""" env = gym.make("CartPole-v1") - env = MaxAndSkipObservationV0(env, skip=skip) + env = MaxAndSkipObservation(env, skip=skip) obs, _ = env.reset() assert obs in env.observation_space for i in range(10): - obs, _, _, _, _ = env.step(env.action_space.sample()) + obs, _, term, trunc, _ = env.step(env.action_space.sample()) assert obs in env.observation_space + if term or trunc: + obs, _ = env.reset() + assert obs in env.observation_space + def test_skip_size_failures(): """Test the error raised by the MaxAndSkipObservation.""" @@ -31,7 +35,7 @@ def test_skip_size_failures(): "The skip is expected to be an integer, actual type: " ), ): - MaxAndSkipObservationV0(env, skip=1.0) + MaxAndSkipObservation(env, skip=1.0) with pytest.raises( ValueError, @@ -39,4 +43,4 @@ def test_skip_size_failures(): "The skip value needs to be equal or greater than two, actual value: 0" ), ): - MaxAndSkipObservationV0(env, skip=0) + MaxAndSkipObservation(env, skip=0) diff --git a/tests/wrappers/test_nested_dict.py b/tests/wrappers/test_nested_dict.py deleted file mode 100644 index b4e4b018d..000000000 --- a/tests/wrappers/test_nested_dict.py +++ /dev/null @@ -1,120 +0,0 @@ -"""Tests for the filter observation wrapper.""" -from typing import Optional - -import numpy as np -import pytest - -import gymnasium as gym -from 
gymnasium.spaces import Box, Dict, Tuple -from gymnasium.wrappers import FilterObservation, FlattenObservation - - -class FakeEnvironment(gym.Env): - def __init__(self, observation_space, render_mode=None): - self.observation_space = observation_space - self.obs_keys = self.observation_space.spaces.keys() - self.action_space = Box(shape=(1,), low=-1, high=1, dtype=np.float32) - self.render_mode = render_mode - - def render(self, mode="human"): - image_shape = (32, 32, 3) - return np.zeros(image_shape, dtype=np.uint8) - - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): - super().reset(seed=seed) - observation = self.observation_space.sample() - return observation, {} - - def step(self, action): - del action - observation = self.observation_space.sample() - reward, terminal, info = 0.0, False, {} - return observation, reward, terminal, info - - -NESTED_DICT_TEST_CASES = ( - ( - Dict( - { - "key1": Box(shape=(2,), low=-1, high=1, dtype=np.float32), - "key2": Dict( - { - "subkey1": Box(shape=(2,), low=-1, high=1, dtype=np.float32), - "subkey2": Box(shape=(2,), low=-1, high=1, dtype=np.float32), - } - ), - } - ), - (6,), - ), - ( - Dict( - { - "key1": Box(shape=(2, 3), low=-1, high=1, dtype=np.float32), - "key2": Box(shape=(), low=-1, high=1, dtype=np.float32), - "key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32), - } - ), - (9,), - ), - ( - Dict( - { - "key1": Tuple( - ( - Box(shape=(2,), low=-1, high=1, dtype=np.float32), - Box(shape=(2,), low=-1, high=1, dtype=np.float32), - ) - ), - "key2": Box(shape=(), low=-1, high=1, dtype=np.float32), - "key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32), - } - ), - (7,), - ), - ( - Dict( - { - "key1": Tuple((Box(shape=(2,), low=-1, high=1, dtype=np.float32),)), - "key2": Box(shape=(), low=-1, high=1, dtype=np.float32), - "key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32), - } - ), - (5,), - ), - ( - Dict( - { - "key1": Tuple( - (Dict({"key9": Box(shape=(2,), low=-1, high=1, dtype=np.float32)}),) - ), - "key2": Box(shape=(), low=-1, high=1, dtype=np.float32), - "key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32), - } - ), - (5,), - ), -) - - -class TestNestedDictWrapper: - @pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES) - def test_nested_dicts_size(self, observation_space, flat_shape): - env = FakeEnvironment(observation_space=observation_space) - - # Make sure we are testing the right environment for the test. 
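# The nested-dict cases deleted here are exercised going forward by composing the
# two wrappers directly; a minimal sketch using the helpers the new test files
# rely on (GenericTestEnv from tests/testing_env.py; the spaces are illustrative):
from gymnasium.spaces import Box, Dict
from gymnasium.wrappers import FilterObservation, FlattenObservation
from tests.testing_env import GenericTestEnv

env = GenericTestEnv(observation_space=Dict(a=Box(0, 1), b=Box(2, 3)))
wrapped = FlattenObservation(FilterObservation(env, ("a",)))
assert wrapped.observation_space.shape == (1,)  # only "a" survives the filter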
- observation_space = env.observation_space - assert isinstance(observation_space, Dict) - - wrapped_env = FlattenObservation(FilterObservation(env, list(env.obs_keys))) - assert wrapped_env.observation_space.shape == flat_shape - - assert wrapped_env.observation_space.dtype == np.float32 - - @pytest.mark.parametrize("observation_space, flat_shape", NESTED_DICT_TEST_CASES) - def test_nested_dicts_ravel(self, observation_space, flat_shape): - env = FakeEnvironment(observation_space=observation_space) - wrapped_env = FlattenObservation(FilterObservation(env, list(env.obs_keys))) - obs, info = wrapped_env.reset() - assert obs.shape == wrapped_env.observation_space.shape - assert isinstance(info, dict) diff --git a/tests/wrappers/test_normalize.py b/tests/wrappers/test_normalize.py deleted file mode 100644 index d8a549816..000000000 --- a/tests/wrappers/test_normalize.py +++ /dev/null @@ -1,125 +0,0 @@ -from typing import Optional - -import numpy as np -from numpy.testing import assert_almost_equal - -import gymnasium as gym -from gymnasium.wrappers.normalize import NormalizeObservation, NormalizeReward - - -class DummyRewardEnv(gym.Env): - metadata = {} - - def __init__(self, return_reward_idx=0): - self.action_space = gym.spaces.Discrete(2) - self.observation_space = gym.spaces.Box( - low=np.array([-1.0]), high=np.array([1.0]), dtype=np.float64 - ) - self.returned_rewards = [0, 1, 2, 3, 4] - self.return_reward_idx = return_reward_idx - self.t = self.return_reward_idx - - def step(self, action): - self.t += 1 - return ( - np.array([self.t]), - self.t, - self.t == len(self.returned_rewards), - False, - {}, - ) - - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): - super().reset(seed=seed) - self.t = self.return_reward_idx - return np.array([self.t]), {} - - -def make_env(return_reward_idx): - def thunk(): - env = DummyRewardEnv(return_reward_idx) - return env - - return thunk - - -def test_normalize_observation(): - env = DummyRewardEnv(return_reward_idx=0) - env = NormalizeObservation(env) - env.reset() - env.step(env.action_space.sample()) - assert_almost_equal(env.obs_rms.mean, 0.5, decimal=4) - env.step(env.action_space.sample()) - assert_almost_equal(env.obs_rms.mean, 1.0, decimal=4) - - -def test_normalize_reset_info(): - env = DummyRewardEnv(return_reward_idx=0) - env = NormalizeObservation(env) - obs, info = env.reset() - assert isinstance(obs, np.ndarray) - assert isinstance(info, dict) - - -def test_normalize_return(): - env = DummyRewardEnv(return_reward_idx=0) - env = NormalizeReward(env) - env.reset() - env.step(env.action_space.sample()) - assert_almost_equal( - env.return_rms.mean, - np.mean([1]), # [first return] - decimal=4, - ) - env.step(env.action_space.sample()) - assert_almost_equal( - env.return_rms.mean, - np.mean([2 + env.gamma * 1, 1]), # [second return, first return] - decimal=4, - ) - - -def test_normalize_observation_vector_env(): - env_fns = [make_env(0), make_env(1)] - envs = gym.vector.SyncVectorEnv(env_fns) - envs.reset() - obs, reward, _, _, _ = envs.step(envs.action_space.sample()) - np.testing.assert_almost_equal(obs, np.array([[1], [2]]), decimal=4) - np.testing.assert_almost_equal(reward, np.array([1, 2]), decimal=4) - - env_fns = [make_env(0), make_env(1)] - envs = gym.vector.SyncVectorEnv(env_fns) - envs = NormalizeObservation(envs) - envs.reset() - assert_almost_equal( - envs.obs_rms.mean, - np.mean([0.5]), # the mean of first observations [[0, 1]] - decimal=4, - ) - obs, reward, _, _, _ = 
envs.step(envs.action_space.sample())
-    assert_almost_equal(
-        envs.obs_rms.mean,
-        np.mean([1.0]),  # the mean of first and second observations [[0, 1], [1, 2]]
-        decimal=4,
-    )
-
-
-def test_normalize_return_vector_env():
-    env_fns = [make_env(0), make_env(1)]
-    envs = gym.vector.SyncVectorEnv(env_fns)
-    envs = NormalizeReward(envs)
-    obs = envs.reset()
-    obs, reward, _, _, _ = envs.step(envs.action_space.sample())
-    assert_almost_equal(
-        envs.return_rms.mean,
-        np.mean([1.5]),  # the mean of first returns [[1, 2]]
-        decimal=4,
-    )
-    obs, reward, _, _, _ = envs.step(envs.action_space.sample())
-    assert_almost_equal(
-        envs.return_rms.mean,
-        np.mean(
-            [[1, 2], [2 + envs.gamma * 1, 3 + envs.gamma * 2]]
-        ),  # the mean of first and second returns [[1, 2], [2 + envs.gamma * 1, 3 + envs.gamma * 2]]
-        decimal=4,
-    )
diff --git a/tests/experimental/wrappers/test_normalize_observation.py b/tests/wrappers/test_normalize_observation.py
similarity index 81%
rename from tests/experimental/wrappers/test_normalize_observation.py
rename to tests/wrappers/test_normalize_observation.py
index 22889dbe7..1fccc3b97 100644
--- a/tests/experimental/wrappers/test_normalize_observation.py
+++ b/tests/wrappers/test_normalize_observation.py
@@ -1,12 +1,13 @@
-"""Test suite for NormalizeObservationV0."""
-from gymnasium.experimental.wrappers import NormalizeObservationV0
+"""Test suite for NormalizeObservation wrapper."""
+
+from gymnasium.wrappers import NormalizeObservation
 from tests.testing_env import GenericTestEnv
 
 
-def test_running_mean_normalize_observation_wrapper():
+def test_update_running_mean_property():
     """Tests that the property `_update_running_mean` freezes/continues the running statistics updating."""
     env = GenericTestEnv()
-    wrapped_env = NormalizeObservationV0(env)
+    wrapped_env = NormalizeObservation(env)
 
     # Default value is True
     assert wrapped_env.update_running_mean
diff --git a/tests/wrappers/test_normalize_reward.py b/tests/wrappers/test_normalize_reward.py
new file mode 100644
index 000000000..9e467a227
--- /dev/null
+++ b/tests/wrappers/test_normalize_reward.py
@@ -0,0 +1,83 @@
+"""Test suite for NormalizeReward wrapper."""
+
+import numpy as np
+
+import gymnasium as gym
+from gymnasium.core import ActType
+from gymnasium.wrappers import NormalizeReward
+from tests.testing_env import GenericTestEnv
+
+
+def constant_reward_step_func(self, action: ActType):
+    return self.observation_space.sample(), 1.0, False, False, {}
+
+
+def test_running_mean_normalize_reward_wrapper():
+    """Tests that the property `_update_running_mean` freezes/continues the running statistics updating."""
+    env = GenericTestEnv(step_func=constant_reward_step_func)
+    wrapped_env = NormalizeReward(env)
+
+    # Default value is True
+    assert wrapped_env.update_running_mean
+
+    wrapped_env.reset()
+    rms_var_init = wrapped_env.return_rms.var
+    rms_mean_init = wrapped_env.return_rms.mean
+
+    # Statistics are updated when env.step() is called
+    wrapped_env.step(None)
+    rms_var_updated = wrapped_env.return_rms.var
+    rms_mean_updated = wrapped_env.return_rms.mean
+    assert rms_var_init != rms_var_updated
+    assert rms_mean_init != rms_mean_updated
+
+    # Ensure the property is set
+    wrapped_env.update_running_mean = False
+    assert not wrapped_env.update_running_mean
+
+    # Statistics are frozen
+    wrapped_env.step(None)
+    assert rms_var_updated == wrapped_env.return_rms.var
+    assert rms_mean_updated == wrapped_env.return_rms.mean
+
+
+def test_normalize_reward_wrapper():
+    """Tests that the NormalizeReward wrapper does not throw an error."""
+    # TODO: Functional correctness should be tested
+    env = GenericTestEnv(step_func=constant_reward_step_func)
+    wrapped_env = NormalizeReward(env)
+    wrapped_env.reset()
+    _, reward, _, _, _ = wrapped_env.step(None)
+    assert np.ndim(reward) == 0
+    env.close()
+
+
+def reward_reset_func(self: gym.Env, seed=None, options=None):
+    self.rewards = [0, 1, 2, 3, 4]
+    reward = self.rewards.pop(0)
+    return np.array([reward]), {"reward": reward}
+
+
+def reward_step_func(self: gym.Env, action):
+    reward = self.rewards.pop(0)
+    return np.array([reward]), reward, len(self.rewards) == 0, False, {"reward": reward}
+
+
+def test_normalize_return():
+    env = GenericTestEnv(reset_func=reward_reset_func, step_func=reward_step_func)
+    env = NormalizeReward(env)
+    env.reset()
+
+    env.step(env.action_space.sample())
+    np.testing.assert_almost_equal(
+        env.return_rms.mean,
+        np.mean([1]),  # [first return]
+        decimal=4,
+    )
+
+    env.step(env.action_space.sample())
+    np.testing.assert_almost_equal(
+        env.return_rms.mean,
+        np.mean([2 + 1 * env.gamma, 1]),  # [second return, first return]
+        decimal=4,
+    )
diff --git a/tests/wrappers/test_numpy_to_torch.py b/tests/wrappers/test_numpy_to_torch.py
new file mode 100644
index 000000000..8c304ab7d
--- /dev/null
+++ b/tests/wrappers/test_numpy_to_torch.py
@@ -0,0 +1,110 @@
+"""Test suite for NumpyToTorch wrapper."""
+
+import numpy as np
+import pytest
+
+
+torch = pytest.importorskip("torch")
+
+
+from gymnasium.utils.env_checker import data_equivalence  # noqa: E402
+from gymnasium.wrappers.numpy_to_torch import (  # noqa: E402
+    NumpyToTorch,
+    numpy_to_torch,
+    torch_to_numpy,
+)
+from tests.testing_env import GenericTestEnv  # noqa: E402
+
+
+@pytest.mark.parametrize(
+    "value, expected_value",
+    [
+        (1.0, np.array(1.0, dtype=np.float32)),
+        (2, np.array(2, dtype=np.int64)),
+        ((3.0, 4), (np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int64))),
+        ([3.0, 4], [np.array(3.0, dtype=np.float32), np.array(4, dtype=np.int64)]),
+        (
+            {
+                "a": 6.0,
+                "b": 7,
+            },
+            {"a": np.array(6.0, dtype=np.float32), "b": np.array(7, dtype=np.int64)},
+        ),
+        (np.array(1.0, dtype=np.float32), np.array(1.0, dtype=np.float32)),
+        (np.array(1.0, dtype=np.uint8), np.array(1.0, dtype=np.uint8)),
+        (np.array([1, 2], dtype=np.int32), np.array([1, 2], dtype=np.int32)),
+        (
+            np.array([[1.0], [2.0]], dtype=np.int32),
+            np.array([[1.0], [2.0]], dtype=np.int32),
+        ),
+        (
+            {
+                "a": (
+                    1,
+                    np.array(2.0, dtype=np.float32),
+                    np.array([3, 4], dtype=np.int32),
+                ),
+                "b": {"c": 5},
+            },
+            {
+                "a": (
+                    np.array(1, dtype=np.int64),
+                    np.array(2.0, dtype=np.float32),
+                    np.array([3, 4], dtype=np.int32),
+                ),
+                "b": {"c": np.array(5, dtype=np.int64)},
+            },
+        ),
+    ],
+)
+def test_roundtripping(value, expected_value):
+    """We test numpy -> torch -> numpy as this is the direction used by the NumpyToTorch wrapper."""
+    torch_value = numpy_to_torch(value)
+    roundtripped_value = torch_to_numpy(torch_value)
+    # roundtripped_value = torch_to_numpy(numpy_to_torch(value))
+    assert data_equivalence(roundtripped_value, expected_value)
+
+
+def numpy_reset_func(self, seed=None, options=None):
+    """A Numpy-based reset function."""
+    return np.array([1.0, 2.0, 3.0]), {"data": np.array([1, 2, 3])}
+
+
+def numpy_step_func(self, action):
+    """A Numpy-based step function."""
+    assert isinstance(action, np.ndarray), type(action)
+    return (
+        np.array([1, 2, 3]),
+        5.0,
+        True,
+        False,
+        {"data": np.array([1.0, 2.0])},
+    )
+
+
+def test_numpy_to_torch():
+    """Tests the ``NumpyToTorch`` wrapper."""
+    numpy_env = 
GenericTestEnv(reset_func=numpy_reset_func, step_func=numpy_step_func) + obs, info = numpy_env.reset() + assert isinstance(obs, np.ndarray) + assert isinstance(info, dict) and isinstance(info["data"], np.ndarray) + + obs, reward, terminated, truncated, info = numpy_env.step(np.array([1, 2])) + assert isinstance(obs, np.ndarray) + assert isinstance(reward, float) + assert isinstance(terminated, bool) and isinstance(truncated, bool) + assert isinstance(info, dict) and isinstance(info["data"], np.ndarray) + + # Check that the wrapped version is correct. + torch_env = NumpyToTorch(numpy_env) + + # Check that the reset and step for torch environment are as expected + obs, info = torch_env.reset() + assert isinstance(obs, torch.Tensor) + assert isinstance(info, dict) and isinstance(info["data"], torch.Tensor) + + obs, reward, terminated, truncated, info = torch_env.step(torch.tensor([1, 2])) + assert isinstance(obs, torch.Tensor) + assert isinstance(reward, float) + assert isinstance(terminated, bool) and isinstance(truncated, bool) + assert isinstance(info, dict) and isinstance(info["data"], torch.Tensor) diff --git a/tests/wrappers/test_order_enforcing.py b/tests/wrappers/test_order_enforcing.py index b7b46476e..428827be1 100644 --- a/tests/wrappers/test_order_enforcing.py +++ b/tests/wrappers/test_order_enforcing.py @@ -1,23 +1,12 @@ +"""Test suite for OrderEnforcing wrapper.""" import pytest -import gymnasium as gym from gymnasium.envs.classic_control import CartPoleEnv from gymnasium.error import ResetNeeded from gymnasium.wrappers import OrderEnforcing -from tests.envs.utils import all_testing_env_specs from tests.wrappers.utils import has_wrapper -@pytest.mark.parametrize( - "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs] -) -def test_gym_make_order_enforcing(spec): - """Checks that gym.make wrappers the environment with the OrderEnforcing wrapper.""" - env = gym.make(spec.id, disable_env_checker=True) - - assert has_wrapper(env, OrderEnforcing) - - def test_order_enforcing(): """Checks that the order enforcing works as expected, raising an error before reset is called and not after.""" # The reason for not using gym.make is that all environments are by default wrapped in the order enforcing wrapper diff --git a/tests/wrappers/test_passive_env_checker.py b/tests/wrappers/test_passive_env_checker.py index 52e97f2aa..c5b9f746c 100644 --- a/tests/wrappers/test_passive_env_checker.py +++ b/tests/wrappers/test_passive_env_checker.py @@ -1,3 +1,5 @@ +"""Test suite for PassiveEnvChecker wrapper.""" + import re import warnings @@ -5,7 +7,7 @@ import numpy as np import pytest import gymnasium as gym -from gymnasium.wrappers.env_checker import PassiveEnvChecker +from gymnasium.wrappers import PassiveEnvChecker from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING from tests.envs.utils import all_testing_initialised_envs from tests.testing_env import GenericTestEnv diff --git a/tests/wrappers/test_pixel_observation.py b/tests/wrappers/test_pixel_observation.py deleted file mode 100644 index 7ada4894d..000000000 --- a/tests/wrappers/test_pixel_observation.py +++ /dev/null @@ -1,125 +0,0 @@ -"""Tests for the pixel observation wrapper.""" -from typing import Optional - -import numpy as np -import pytest - -import gymnasium as gym -from gymnasium import spaces -from gymnasium.wrappers.pixel_observation import STATE_KEY, PixelObservationWrapper - - -class FakeEnvironment(gym.Env): - def __init__(self, render_mode="single_rgb_array"): - self.action_space = 
spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32) - self.render_mode = render_mode - - def render(self, mode="human", width=32, height=32): - image_shape = (height, width, 3) - return np.zeros(image_shape, dtype=np.uint8) - - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): - super().reset(seed=seed) - observation = self.observation_space.sample() - return observation, {} - - def step(self, action): - del action - observation = self.observation_space.sample() - reward, terminal, info = 0.0, False, {} - return observation, reward, terminal, info - - -class FakeArrayObservationEnvironment(FakeEnvironment): - def __init__(self, *args, **kwargs): - self.observation_space = spaces.Box( - shape=(2,), low=-1, high=1, dtype=np.float32 - ) - super().__init__(*args, **kwargs) - - -class FakeDictObservationEnvironment(FakeEnvironment): - def __init__(self, *args, **kwargs): - self.observation_space = spaces.Dict( - { - "state": spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32), - } - ) - super().__init__(*args, **kwargs) - - -@pytest.mark.parametrize("pixels_only", (True, False)) -def test_dict_observation(pixels_only): - pixel_key = "rgb" - - env = FakeDictObservationEnvironment() - - # Make sure we are testing the right environment for the test. - observation_space = env.observation_space - assert isinstance(observation_space, spaces.Dict) - - width, height = (320, 240) - - # The wrapper should only add one observation. - wrapped_env = PixelObservationWrapper( - env, - pixel_keys=(pixel_key,), - pixels_only=pixels_only, - render_kwargs={pixel_key: {"width": width, "height": height}}, - ) - - assert isinstance(wrapped_env.observation_space, spaces.Dict) - - if pixels_only: - assert len(wrapped_env.observation_space.spaces) == 1 - assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key] - else: - assert ( - len(wrapped_env.observation_space.spaces) - == len(observation_space.spaces) + 1 - ) - expected_keys = list(observation_space.spaces.keys()) + [pixel_key] - assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys - - # Check that the added space item is consistent with the added observation. 
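# In this patch PixelObservationWrapper's role moves to the RenderObservation
# wrapper (see tests/wrappers/test_render_observation.py below); a minimal sketch
# of the replacement API as those tests exercise it (shapes illustrative):
import numpy as np
from gymnasium.spaces import Box
from gymnasium.wrappers import RenderObservation
from tests.testing_env import GenericTestEnv

env = GenericTestEnv(
    observation_space=Box(-1, 1, shape=(2,)),
    render_mode="rgb_array",
    render_func=lambda self: np.zeros((32, 32, 3), dtype=np.uint8),
)
wrapped = RenderObservation(env, render_key="pixels", render_only=False)
obs, _ = wrapped.reset()
assert obs["pixels"].shape == (32, 32, 3)  # rendered frame joins the state obs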
- observation, info = wrapped_env.reset() - rgb_observation = observation[pixel_key] - - assert isinstance(info, dict) - assert rgb_observation.shape == (height, width, 3) - assert rgb_observation.dtype == np.uint8 - - -@pytest.mark.parametrize("pixels_only", (True, False)) -def test_single_array_observation(pixels_only): - pixel_key = "depth" - - env = FakeArrayObservationEnvironment() - observation_space = env.observation_space - assert isinstance(observation_space, spaces.Box) - - wrapped_env = PixelObservationWrapper( - env, pixel_keys=(pixel_key,), pixels_only=pixels_only - ) - wrapped_env.observation_space = wrapped_env.observation_space - assert isinstance(wrapped_env.observation_space, spaces.Dict) - - if pixels_only: - assert len(wrapped_env.observation_space.spaces) == 1 - assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key] - else: - assert len(wrapped_env.observation_space.spaces) == 2 - assert list(wrapped_env.observation_space.spaces.keys()) == [ - STATE_KEY, - pixel_key, - ] - - observation, info = wrapped_env.reset() - depth_observation = observation[pixel_key] - - assert isinstance(info, dict) - assert depth_observation.shape == (32, 32, 3) - assert depth_observation.dtype == np.uint8 - - if not pixels_only: - assert isinstance(observation[STATE_KEY], np.ndarray) diff --git a/tests/wrappers/test_record_episode_statistics.py b/tests/wrappers/test_record_episode_statistics.py index f277ea937..06897b80c 100644 --- a/tests/wrappers/test_record_episode_statistics.py +++ b/tests/wrappers/test_record_episode_statistics.py @@ -1,8 +1,9 @@ -import numpy as np +"""Test suite for RecordEpisodeStatistics wrapper.""" + import pytest import gymnasium as gym -from gymnasium.wrappers import RecordEpisodeStatistics, VectorListInfo +from gymnasium.wrappers import RecordEpisodeStatistics @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) @@ -14,8 +15,8 @@ def test_record_episode_statistics(env_id, deque_size): for n in range(5): env.reset() assert env.episode_returns is not None and env.episode_lengths is not None - assert env.episode_returns[0] == 0.0 - assert env.episode_lengths[0] == 0 + assert env.episode_returns == 0.0 + assert env.episode_lengths == 0 assert env.spec is not None for t in range(env.spec.max_episode_steps): _, _, terminated, truncated, info = env.step(env.action_space.sample()) @@ -25,52 +26,3 @@ def test_record_episode_statistics(env_id, deque_size): break assert len(env.return_queue) == deque_size assert len(env.length_queue) == deque_size - - -def test_record_episode_statistics_reset_info(): - env = gym.make("CartPole-v1", disable_env_checker=True) - env = RecordEpisodeStatistics(env) - ob_space = env.observation_space - obs, info = env.reset() - assert ob_space.contains(obs) - assert isinstance(info, dict) - - -@pytest.mark.parametrize( - ("num_envs", "asynchronous"), [(1, False), (1, True), (4, False), (4, True)] -) -def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous): - envs = gym.vector.make( - "CartPole-v1", - render_mode=None, - num_envs=num_envs, - asynchronous=asynchronous, - disable_env_checker=True, - ) - envs = RecordEpisodeStatistics(envs) - max_episode_step = ( - envs.env_fns[0]().spec.max_episode_steps - if asynchronous - else envs.env.envs[0].spec.max_episode_steps - ) - envs.reset() - for _ in range(max_episode_step + 1): - _, _, terminateds, truncateds, infos = envs.step(envs.action_space.sample()) - if any(terminateds) or any(truncateds): - assert "episode" in infos - assert "_episode" in infos 
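# The vector-env cases deleted below are superseded by the dedicated vector
# wrapper this patch introduces; a minimal sketch, assuming the new module layout
# (gymnasium.wrappers.vector) and gym.make_vec:
import gymnasium as gym
from gymnasium.wrappers.vector import RecordEpisodeStatistics

envs = RecordEpisodeStatistics(gym.make_vec("CartPole-v1", num_envs=2))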
- assert all(infos["_episode"] == np.bitwise_or(terminateds, truncateds)) - assert all([item in infos["episode"] for item in ["r", "l", "t"]]) - break - else: - assert "episode" not in infos - assert "_episode" not in infos - - -def test_wrong_wrapping_order(): - envs = gym.vector.make("CartPole-v1", num_envs=3, disable_env_checker=True) - wrapped_env = RecordEpisodeStatistics(VectorListInfo(envs)) - wrapped_env.reset() - - with pytest.raises(AssertionError): - wrapped_env.step(wrapped_env.action_space.sample()) diff --git a/tests/wrappers/test_record_video.py b/tests/wrappers/test_record_video.py index 6395f18f9..b0659cdee 100644 --- a/tests/wrappers/test_record_video.py +++ b/tests/wrappers/test_record_video.py @@ -1,46 +1,64 @@ +"""Test suite for RecordVideo wrapper.""" import os import shutil +from typing import List import gymnasium as gym -from gymnasium.wrappers import capped_cubic_video_schedule +from gymnasium.wrappers import RecordVideo def test_record_video_using_default_trigger(): - env = gym.make( - "CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True - ) - env = gym.wrappers.RecordVideo(env, "videos") + """Test RecordVideo using the default episode trigger.""" + env = gym.make("CartPole-v1", render_mode="rgb_array_list") + env = RecordVideo(env, "videos") env.reset() + episode_count = 0 for _ in range(199): action = env.action_space.sample() _, _, terminated, truncated, _ = env.step(action) if terminated or truncated: env.reset() + episode_count += 1 + env.close() assert os.path.isdir("videos") mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] + assert env.episode_trigger is not None assert len(mp4_files) == sum( - capped_cubic_video_schedule(i) for i in range(env.episode_id + 1) + env.episode_trigger(i) for i in range(episode_count + 1) ) shutil.rmtree("videos") -def test_record_video_reset(): - env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True) - env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) - ob_space = env.observation_space - obs, info = env.reset() +def test_record_video_while_rendering(): + """Test RecordVideo while calling render and using a _list render mode.""" + env = gym.make("FrozenLake-v1", render_mode="rgb_array_list") + env = RecordVideo(env, "videos") + env.reset() + episode_count = 0 + for _ in range(199): + action = env.action_space.sample() + _, _, terminated, truncated, _ = env.step(action) + env.render() + if terminated or truncated: + env.reset() + episode_count += 1 + env.close() assert os.path.isdir("videos") + mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] + assert env.episode_trigger is not None + assert len(mp4_files) == sum( + env.episode_trigger(i) for i in range(episode_count + 1) + ) shutil.rmtree("videos") - assert ob_space.contains(obs) - assert isinstance(info, dict) def test_record_video_step_trigger(): + """Test RecordVideo defining step trigger function.""" env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True) env._max_episode_steps = 20 - env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) + env = RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) env.reset() for _ in range(199): action = env.action_space.sample() @@ -50,37 +68,73 @@ def test_record_video_step_trigger(): env.close() assert os.path.isdir("videos") mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] - assert len(mp4_files) == 2 
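# Note on the expected count: with the default video_length=0, each triggered
# recording runs to the end of its episode, so a step_trigger firing at steps 0
# and 100 over 199 steps yields the two files asserted here. Sketch of the wiring:
recorded_env = RecordVideo(env, "videos", step_trigger=lambda step: step % 100 == 0)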
shutil.rmtree("videos") + assert len(mp4_files) == 2 -def make_env(gym_id, seed, **kwargs): - def thunk(): - env = gym.make(gym_id, disable_env_checker=True, **kwargs) - env._max_episode_steps = 20 - if seed == 1: - env = gym.wrappers.RecordVideo( - env, "videos", step_trigger=lambda x: x % 100 == 0 - ) - return env - - return thunk - - -def test_record_video_within_vector(): - envs = gym.vector.SyncVectorEnv( - [make_env("CartPole-v1", 1 + i, render_mode="rgb_array") for i in range(2)] +def test_record_video_both_trigger(): + """Test RecordVideo defining both step and episode trigger functions.""" + env = gym.make( + "CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True ) - envs = gym.wrappers.RecordEpisodeStatistics(envs) - envs.reset() - for i in range(199): - _, _, _, _, infos = envs.step(envs.action_space.sample()) - - # break when every env is done - if "episode" in infos and all(infos["_episode"]): - print(f"episode_reward={infos['episode']['r']}") - + env._max_episode_steps = 20 + env = RecordVideo( + env, + "videos", + step_trigger=lambda x: x == 100, + episode_trigger=lambda x: x == 0 or x == 3, + ) + env.reset() + for _ in range(199): + action = env.action_space.sample() + _, _, terminated, truncated, _ = env.step(action) + if terminated or truncated: + env.reset() + env.close() assert os.path.isdir("videos") mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] - assert len(mp4_files) == 2 + shutil.rmtree("videos") + assert len(mp4_files) == 3 + + +def test_record_video_length(): + """Test if argument video_length of RecordVideo works properly.""" + env = gym.make("CartPole-v1", render_mode="rgb_array_list") + env._max_episode_steps = 20 + env = RecordVideo(env, "videos", step_trigger=lambda x: x == 0, video_length=10) + env.reset() + for _ in range(10): + _, _, term, trunc, _ = env.step(env.action_space.sample()) + if term or trunc: + break + + assert env.recording + action = env.action_space.sample() + env.step(action) + assert not env.recording + env.close() + assert os.path.isdir("videos") + mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")] + assert len(mp4_files) == 1 + shutil.rmtree("videos") + + +def test_rendering_works(): + """Test if render output is as expected when the env is wrapped with RecordVideo.""" + env = gym.make("CartPole-v1", render_mode="rgb_array_list") + env._max_episode_steps = 20 + env = RecordVideo(env, "videos") + env.reset() + n_steps = 10 + for _ in range(n_steps): + action = env.action_space.sample() + env.step(action) + + render_out = env.render() + assert isinstance(render_out, List) + assert len(render_out) == n_steps + 1 + render_out = env.render() + assert isinstance(render_out, List) + assert len(render_out) == 0 + env.close() shutil.rmtree("videos") diff --git a/tests/wrappers/test_render_observation.py b/tests/wrappers/test_render_observation.py new file mode 100644 index 000000000..629c1dfbe --- /dev/null +++ b/tests/wrappers/test_render_observation.py @@ -0,0 +1,96 @@ +"""Test suite for RenderObservation wrapper.""" +import numpy as np +import pytest + +from gymnasium import spaces +from gymnasium.wrappers import RenderObservation +from tests.testing_env import GenericTestEnv + + +STATE_KEY = "state" + + +def image_render_func(self): + return np.zeros((32, 32, 3), dtype=np.uint8) + + +@pytest.mark.parametrize("pixels_only", (True, False)) +def test_dict_observation(pixels_only, pixel_key="rgb"): + env = GenericTestEnv( + observation_space=spaces.Dict( + 
state=spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32) + ), + render_mode="rgb_array", + render_func=image_render_func, + ) + + # Make sure we are testing the right environment for the test. + assert isinstance(env.observation_space, spaces.Dict) + + # width, height = (320, 240) + + # The wrapper should only add one observation. + wrapped_env = RenderObservation( + env, + render_key=pixel_key, + render_only=pixels_only, + # render_kwargs={pixel_key: {"width": width, "height": height}}, + ) + obs, info = wrapped_env.reset() + if pixels_only: + assert isinstance(wrapped_env.observation_space, spaces.Box) + assert isinstance(obs, np.ndarray) + + rendered_obs = obs + else: + assert isinstance(wrapped_env.observation_space, spaces.Dict) + + expected_keys = [pixel_key] + list(env.observation_space.spaces.keys()) + assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys + + assert isinstance(obs, dict) + rendered_obs = obs[pixel_key] + + # Check that the added space item is consistent with the added observation. + # assert rendered_obs.shape == (height, width, 3) + assert rendered_obs.ndim == 3 + assert rendered_obs.dtype == np.uint8 + + +@pytest.mark.parametrize("pixels_only", (True, False)) +def test_single_array_observation(pixels_only): + pixel_key = "depth" + + env = GenericTestEnv( + observation_space=spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32), + render_mode="rgb_array", + render_func=image_render_func, + ) + assert isinstance(env.observation_space, spaces.Box) + + # The wrapper should only add one observation. + wrapped_env = RenderObservation( + env, + render_key=pixel_key, + render_only=pixels_only, + # render_kwargs={pixel_key: {"width": width, "height": height}}, + ) + obs, info = wrapped_env.reset() + if pixels_only: + assert isinstance(wrapped_env.observation_space, spaces.Box) + assert isinstance(obs, np.ndarray) + + rendered_obs = obs + else: + assert isinstance(wrapped_env.observation_space, spaces.Dict) + + expected_keys = [pixel_key, "state"] + assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys + + assert isinstance(obs, dict) + rendered_obs = obs[pixel_key] + + # Check that the added space item is consistent with the added observation. 
+    # assert rendered_obs.shape == (height, width, 3)
+    assert rendered_obs.ndim == 3
+    assert rendered_obs.dtype == np.uint8
diff --git a/tests/wrappers/test_rescale_action.py b/tests/wrappers/test_rescale_action.py
index 127f53f6f..fcd8b098d 100644
--- a/tests/wrappers/test_rescale_action.py
+++ b/tests/wrappers/test_rescale_action.py
@@ -1,31 +1,38 @@
+"""Test suite for RescaleAction wrapper."""
 import numpy as np
-import pytest
 
-import gymnasium as gym
+from gymnasium.spaces import Box
 from gymnasium.wrappers import RescaleAction
+from tests.testing_env import GenericTestEnv
+from tests.wrappers.utils import record_action_step
 
 
-def test_rescale_action():
-    env = gym.make("CartPole-v1", disable_env_checker=True)
-    with pytest.raises(AssertionError):
-        env = RescaleAction(env, -1, 1)
-    del env
-
-    env = gym.make("Pendulum-v1", disable_env_checker=True)
-    wrapped_env = RescaleAction(
-        gym.make("Pendulum-v1", disable_env_checker=True), -1, 1
+def test_rescale_action_wrapper():
+    """Test that the action is rescaled within the min / max bounds."""
+    env = GenericTestEnv(
+        step_func=record_action_step,
+        action_space=Box(np.array([0, 1]), np.array([1, 3])),
     )
+    wrapped_env = RescaleAction(
+        env, min_action=np.array([-5, 0]), max_action=np.array([5, 1])
+    )
+    assert wrapped_env.action_space == Box(np.array([-5, 0]), np.array([5, 1]))
 
-    seed = 0
+    for sample_action, expected_action in (
+        (
+            np.array([0.0, 0.5], dtype=np.float32),
+            np.array([0.5, 2.0], dtype=np.float32),
+        ),
+        (
+            np.array([-5.0, 0.0], dtype=np.float32),
+            np.array([0.0, 1.0], dtype=np.float32),
+        ),
+        (
+            np.array([5.0, 1.0], dtype=np.float32),
+            np.array([1.0, 3.0], dtype=np.float32),
+        ),
+    ):
+        assert sample_action in wrapped_env.action_space
 
-    obs, info = env.reset(seed=seed)
-    wrapped_obs, wrapped_obs_info = wrapped_env.reset(seed=seed)
-    assert np.allclose(obs, wrapped_obs)
-
-    obs, reward, _, _, _ = env.step([1.5])
-    with pytest.raises(AssertionError):
-        wrapped_env.step([1.5])
-    wrapped_obs, wrapped_reward, _, _, _ = wrapped_env.step([0.75])
-
-    assert np.allclose(obs, wrapped_obs)
-    assert np.allclose(reward, wrapped_reward)
+        _, _, _, _, info = wrapped_env.step(sample_action)
+        assert np.all(info["action"] == expected_action)
diff --git a/tests/experimental/wrappers/test_rescale_observation.py b/tests/wrappers/test_rescale_observation.py
similarity index 85%
rename from tests/experimental/wrappers/test_rescale_observation.py
rename to tests/wrappers/test_rescale_observation.py
index fffb4a9b8..3178fa9b9 100644
--- a/tests/experimental/wrappers/test_rescale_observation.py
+++ b/tests/wrappers/test_rescale_observation.py
@@ -1,14 +1,10 @@
-"""Test suite for RescaleObservationV0."""
+"""Test suite for RescaleObservation wrapper."""
 import numpy as np
 
-from gymnasium.experimental.wrappers import RescaleObservationV0
 from gymnasium.spaces import Box
-from tests.experimental.wrappers.utils import (
-    check_obs,
-    record_action_as_obs_step,
-    record_obs_reset,
-)
+from gymnasium.wrappers import RescaleObservation
 from tests.testing_env import GenericTestEnv
+from tests.wrappers.utils import check_obs, record_action_as_obs_step, record_obs_reset
 
 
 def test_rescale_observation():
@@ -20,7 +16,7 @@ def test_rescale_observation():
         reset_func=record_obs_reset,
         step_func=record_action_as_obs_step,
     )
-    wrapped_env = RescaleObservationV0(
+    wrapped_env = RescaleObservation(
         env,
         min_obs=np.array([-5, 0], dtype=np.float32),
         max_obs=np.array([5, 1], dtype=np.float32),
diff --git a/tests/experimental/wrappers/test_reshape_observation.py b/tests/wrappers/test_reshape_observation.py
similarity index 76%
rename from tests/experimental/wrappers/test_reshape_observation.py
rename to tests/wrappers/test_reshape_observation.py
index f42f759bd..2458daf5f 100644
--- a/tests/experimental/wrappers/test_reshape_observation.py
+++ b/tests/wrappers/test_reshape_observation.py
@@ -1,12 +1,12 @@
-"""Test suite for ReshapeObservationv0."""
-from gymnasium.experimental.wrappers import ReshapeObservationV0
+"""Test suite for ReshapeObservation wrapper."""
 from gymnasium.spaces import Box
-from tests.experimental.wrappers.utils import (
+from gymnasium.wrappers import ReshapeObservation
+from tests.testing_env import GenericTestEnv
+from tests.wrappers.utils import (
     check_obs,
     record_random_obs_reset,
     record_random_obs_step,
 )
-from tests.testing_env import GenericTestEnv
 
 
 def test_reshape_observation_wrapper():
@@ -16,7 +16,7 @@ def test_reshape_observation_wrapper():
         reset_func=record_random_obs_reset,
         step_func=record_random_obs_step,
     )
-    wrapped_env = ReshapeObservationV0(env, (6, 2))
+    wrapped_env = ReshapeObservation(env, (6, 2))
 
     obs, info = wrapped_env.reset()
     check_obs(env, wrapped_env, obs, info["obs"])
diff --git a/tests/wrappers/test_resize_observation.py b/tests/wrappers/test_resize_observation.py
index ed0d587b7..233612930 100644
--- a/tests/wrappers/test_resize_observation.py
+++ b/tests/wrappers/test_resize_observation.py
@@ -1,38 +1,65 @@
+"""Test suite for ResizeObservation wrapper."""
+from __future__ import annotations
+
+import numpy as np
 import pytest
 
 import gymnasium as gym
-from gymnasium import spaces
-from gymnasium.wrappers import GrayScaleObservation, ResizeObservation
+from gymnasium.spaces import Box
+from gymnasium.wrappers import ResizeObservation
+from tests.testing_env import GenericTestEnv
+from tests.wrappers.utils import (
+    check_obs,
+    record_random_obs_reset,
+    record_random_obs_step,
+)
 
 
-@pytest.mark.parametrize("env_id", ["CarRacing-v2"])
-@pytest.mark.parametrize("shape", [16, 32, (8, 5), [10, 7]])
-def test_resize_observation(env_id, shape):
-    base_env = gym.make(env_id, disable_env_checker=True)
-    env = ResizeObservation(base_env, shape)
+@pytest.mark.parametrize(
+    "env",
+    (
+        GenericTestEnv(
+            observation_space=Box(0, 255, shape=(60, 60, 3), dtype=np.uint8),
+            reset_func=record_random_obs_reset,
+            step_func=record_random_obs_step,
+        ),
+        GenericTestEnv(
+            observation_space=Box(0, 255, shape=(60, 60), dtype=np.uint8),
+            reset_func=record_random_obs_reset,
+            step_func=record_random_obs_step,
+        ),
+    ),
+)
+def test_resize_observation_wrapper(env):
+    """Test that the ``ResizeObservation`` wrapper resizes the observation to the expected shape."""
 
-    assert isinstance(env.observation_space, spaces.Box)
-    assert env.observation_space.shape[-1] == 3
-    obs, _ = env.reset()
-    if isinstance(shape, int):
-        assert env.observation_space.shape[:2] == (shape, shape)
-        assert obs.shape == (shape, shape, 3)
-    else:
-        assert env.observation_space.shape[:2] == tuple(shape)
-        assert obs.shape == tuple(shape) + (3,)
+    wrapped_env = ResizeObservation(env, (25, 25))
+    assert isinstance(wrapped_env.observation_space, Box)
+    assert wrapped_env.observation_space.shape[:2] == (25, 25)
 
-    # test two-dimensional input by grayscaling the observation
-    gray_env = GrayScaleObservation(base_env, keep_dim=False)
-    env = ResizeObservation(gray_env, shape)
-    obs, _ = env.reset()
-    if isinstance(shape, int):
-        assert env.observation_space.shape == obs.shape == (shape, shape)
-    else:
-        assert env.observation_space.shape == obs.shape == tuple(shape)
+    obs, 
info = wrapped_env.reset() + check_obs(env, wrapped_env, obs, info["obs"]) + + obs, _, _, _, info = wrapped_env.step(None) + check_obs(env, wrapped_env, obs, info["obs"]) + + +@pytest.mark.parametrize("shape", ((10, 10), (20, 20), (60, 60), (100, 100))) +def test_resize_shapes(shape: tuple[int, int]): + env = ResizeObservation(gym.make("CarRacing-v2"), shape) + assert env.observation_space == Box( + low=0, high=255, shape=shape + (3,), dtype=np.uint8 + ) + + obs, info = env.reset() + assert obs in env.observation_space + obs, _, _, _, _ = env.step(env.action_space.sample()) + assert obs in env.observation_space def test_invalid_input(): - env = gym.make("CarRacing-v2", disable_env_checker=True) + env = gym.make("CarRacing-v2") + with pytest.raises(AssertionError): ResizeObservation(env, ()) with pytest.raises(AssertionError): @@ -40,8 +67,8 @@ def test_invalid_input(): with pytest.raises(AssertionError): ResizeObservation(env, (1, 1, 1, 1)) with pytest.raises(AssertionError): - ResizeObservation(env, -1) + ResizeObservation(env, (-1, 1)) with pytest.raises(AssertionError): - ResizeObservation(gym.make("CartPole-v1", disable_env_checker=True), 1) + ResizeObservation(gym.make("CartPole-v1"), (1, 1)) with pytest.raises(AssertionError): - ResizeObservation(gym.make("Blackjack-v1", disable_env_checker=True), 1) + ResizeObservation(gym.make("Blackjack-v1"), (1, 1)) diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py deleted file mode 100644 index 79cf6451c..000000000 --- a/tests/wrappers/test_step_compatibility.py +++ /dev/null @@ -1,98 +0,0 @@ -import numpy as np -import pytest - -import gymnasium as gym -from gymnasium.spaces import Discrete -from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv -from gymnasium.wrappers import StepAPICompatibility - - -class OldStepEnv(gym.Env): - def __init__(self): - self.action_space = Discrete(2) - self.observation_space = Discrete(2) - - def step(self, action): - obs = self.observation_space.sample() - rew = 0 - done = False - info = {} - return obs, rew, done, info - - -class NewStepEnv(gym.Env): - def __init__(self): - self.action_space = Discrete(2) - self.observation_space = Discrete(2) - - def step(self, action): - obs = self.observation_space.sample() - rew = 0 - terminated = False - truncated = False - info = {} - return obs, rew, terminated, truncated, info - - -@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) -@pytest.mark.parametrize("output_truncation_bool", [None, True]) -def test_step_compatibility_to_new_api(env, output_truncation_bool): - if output_truncation_bool is None: - env = StepAPICompatibility(env()) - else: - env = StepAPICompatibility(env(), output_truncation_bool) - step_returns = env.step(0) - _, _, terminated, truncated, _ = step_returns - assert isinstance(terminated, bool) - assert isinstance(truncated, bool) - - -@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) -def test_step_compatibility_to_old_api(env): - env = StepAPICompatibility(env(), False) - step_returns = env.step(0) - assert len(step_returns) == 4 - _, _, done, _ = step_returns - assert isinstance(done, bool) - - -@pytest.mark.parametrize("vector_env", [SyncVectorEnv, AsyncVectorEnv]) -def test_vector_env_step_compatibility_to_old_api(vector_env): - num_envs = 2 - env = vector_env([NewStepEnv for _ in range(num_envs)]) - old_env = StepAPICompatibility(env, False) - - step_returns = old_env.step([0] * num_envs) - assert len(step_returns) == 4 - _, _, dones, _ = step_returns - assert 
isinstance(dones, np.ndarray) - for done in dones: - assert isinstance(done, np.bool_) - - -@pytest.mark.parametrize("apply_api_compatibility", [None, True, False]) -def test_step_compatibility_in_make(apply_api_compatibility): - gym.register("OldStepEnv-v0", entry_point=OldStepEnv) - - if apply_api_compatibility is not None: - env = gym.make( - "OldStepEnv-v0", - apply_api_compatibility=apply_api_compatibility, - disable_env_checker=True, - ) - else: - env = gym.make("OldStepEnv-v0", disable_env_checker=True) - - env.reset() - step_returns = env.step(0) - if apply_api_compatibility: - assert len(step_returns) == 5 - _, _, terminated, truncated, _ = step_returns - assert isinstance(terminated, bool) - assert isinstance(truncated, bool) - else: - assert len(step_returns) == 4 - _, _, done, _ = step_returns - assert isinstance(done, bool) - - gym.envs.registry.pop("OldStepEnv-v0") diff --git a/tests/experimental/wrappers/test_sticky_action.py b/tests/wrappers/test_sticky_action.py similarity index 81% rename from tests/experimental/wrappers/test_sticky_action.py rename to tests/wrappers/test_sticky_action.py index efa51a0eb..bdd4c4858 100644 --- a/tests/experimental/wrappers/test_sticky_action.py +++ b/tests/wrappers/test_sticky_action.py @@ -1,16 +1,16 @@ -"""Test suite for StickyActionV0.""" +"""Test suite for StickyAction wrapper.""" import numpy as np import pytest from gymnasium.error import InvalidProbability -from gymnasium.experimental.wrappers import StickyActionV0 -from tests.experimental.wrappers.utils import NUM_STEPS, record_action_as_obs_step +from gymnasium.wrappers import StickyAction from tests.testing_env import GenericTestEnv +from tests.wrappers.utils import NUM_STEPS, record_action_as_obs_step def test_sticky_action(): """Tests the sticky action wrapper.""" - env = StickyActionV0( + env = StickyAction( GenericTestEnv(step_func=record_action_as_obs_step), repeat_action_probability=0.5, ) @@ -30,6 +30,6 @@ def test_sticky_action_raise(repeat_action_probability): """Tests the sticky action wrapper with probabilities that should raise an error.""" with pytest.raises(InvalidProbability): - StickyActionV0( + StickyAction( GenericTestEnv(), repeat_action_probability=repeat_action_probability ) diff --git a/tests/wrappers/test_time_aware_observation.py b/tests/wrappers/test_time_aware_observation.py index 6d2de6293..7eeaea2ef 100644 --- a/tests/wrappers/test_time_aware_observation.py +++ b/tests/wrappers/test_time_aware_observation.py @@ -1,12 +1,16 @@ +"""Test suite for TimeAwareObservation wrapper.""" +import numpy as np import pytest import gymnasium as gym from gymnasium import spaces +from gymnasium.spaces import Box, Dict, Tuple from gymnasium.wrappers import TimeAwareObservation +from tests.testing_env import GenericTestEnv @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) -def test_time_aware_observation(env_id): +def test_default(env_id): env = gym.make(env_id, disable_env_checker=True) wrapped_env = TimeAwareObservation(env) @@ -16,21 +20,92 @@ obs, info = env.reset() wrapped_obs, wrapped_obs_info = wrapped_env.reset() - assert wrapped_env.t == 0.0 - assert wrapped_obs[-1] == 0.0 + assert wrapped_env.timesteps == 0.0 + assert wrapped_obs[-1] == 0.0, wrapped_obs assert wrapped_obs.shape[0] == obs.shape[0] + 1 wrapped_obs, _, _, _, _ = wrapped_env.step(env.action_space.sample()) - assert wrapped_env.t == 1.0 + assert wrapped_env.timesteps == 1.0 assert wrapped_obs[-1] == 1.0 assert
wrapped_obs.shape[0] == obs.shape[0] + 1 wrapped_obs, _, _, _, _ = wrapped_env.step(env.action_space.sample()) - assert wrapped_env.t == 2.0 + assert wrapped_env.timesteps == 2.0 assert wrapped_obs[-1] == 2.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 wrapped_obs, wrapped_obs_info = wrapped_env.reset() - assert wrapped_env.t == 0.0 + assert wrapped_env.timesteps == 0.0 assert wrapped_obs[-1] == 0.0 assert wrapped_obs.shape[0] == obs.shape[0] + 1 + + +def test_no_flatten(): + """Test the TimeAwareObservation wrapper without flattening the space.""" + env = GenericTestEnv(observation_space=Box(0, 1)) + wrapped_env = TimeAwareObservation(env) + assert isinstance(wrapped_env.observation_space, Box) + reset_obs, _ = wrapped_env.reset() + step_obs, _, _, _, _ = wrapped_env.step(None) + assert reset_obs.shape == (2,) and step_obs.shape == (2,) + + assert reset_obs in wrapped_env.observation_space + assert step_obs in wrapped_env.observation_space + + +def test_with_flatten(): + """Test the flatten parameter for the TimeAwareObservation wrapper on three types of observation spaces.""" + env = GenericTestEnv(observation_space=Dict(arm_1=Box(0, 1), arm_2=Box(2, 3))) + wrapped_env = TimeAwareObservation(env, flatten=False) + assert isinstance(wrapped_env.observation_space, Dict) + reset_obs, _ = wrapped_env.reset() + step_obs, _, _, _, _ = wrapped_env.step(None) + assert "time" in reset_obs and "time" in step_obs, f"{reset_obs}, {step_obs}" + + assert reset_obs in wrapped_env.observation_space + assert step_obs in wrapped_env.observation_space + + env = GenericTestEnv(observation_space=Tuple((Box(0, 1), Box(2, 3)))) + wrapped_env = TimeAwareObservation(env, flatten=False) + assert isinstance(wrapped_env.observation_space, Tuple) + reset_obs, _ = wrapped_env.reset() + step_obs, _, _, _, _ = wrapped_env.step(None) + assert len(reset_obs) == 3 and len(step_obs) == 3 + + assert reset_obs in wrapped_env.observation_space + assert step_obs in wrapped_env.observation_space + + env = GenericTestEnv(observation_space=Box(0, 1)) + wrapped_env = TimeAwareObservation(env, flatten=False) + assert isinstance(wrapped_env.observation_space, Dict) + reset_obs, _ = wrapped_env.reset() + step_obs, _, _, _, _ = wrapped_env.step(None) + assert isinstance(reset_obs, dict) and isinstance(step_obs, dict) + assert "obs" in reset_obs and "obs" in step_obs + assert "time" in reset_obs and "time" in step_obs + + assert reset_obs in wrapped_env.observation_space + assert step_obs in wrapped_env.observation_space + + +def test_normalize_time(): + """Test the normalize_time parameter for the TimeAwareObservation wrapper.""" + env = GenericTestEnv(observation_space=Box(0, 1)) + wrapped_env = TimeAwareObservation(env, flatten=False, normalize_time=False) + reset_obs, _ = wrapped_env.reset() + step_obs, _, _, _, _ = wrapped_env.step(None) + assert reset_obs["time"] == np.array([0], dtype=np.int32) and step_obs[ + "time" + ] == np.array([1], dtype=np.int32) + + assert reset_obs in wrapped_env.observation_space + assert step_obs in wrapped_env.observation_space + + env = GenericTestEnv(observation_space=Box(0, 1)) + wrapped_env = TimeAwareObservation(env, flatten=False, normalize_time=True) + reset_obs, _ = wrapped_env.reset() + step_obs, _, _, _, _ = wrapped_env.step(None) + assert reset_obs["time"] == 0.0 and step_obs["time"] == 0.01 + + assert reset_obs in wrapped_env.observation_space + assert step_obs in wrapped_env.observation_space diff --git a/tests/wrappers/test_time_limit.py b/tests/wrappers/test_time_limit.py index
ca479a818..2d5261963 100644 --- a/tests/wrappers/test_time_limit.py +++ b/tests/wrappers/test_time_limit.py @@ -1,3 +1,5 @@ +"""Test suite for TimeLimit wrapper.""" + import pytest import gymnasium as gym diff --git a/tests/wrappers/test_transform_observation.py b/tests/wrappers/test_transform_observation.py deleted file mode 100644 index 0951ca66f..000000000 --- a/tests/wrappers/test_transform_observation.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np -import pytest - -import gymnasium as gym -from gymnasium.wrappers import TransformObservation - - -@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) -def test_transform_observation(env_id): - def affine_transform(x): - return 3 * x + 2 - - env = gym.make(env_id, disable_env_checker=True) - wrapped_env = TransformObservation( - gym.make(env_id, disable_env_checker=True), - lambda obs: affine_transform(obs), - ) - - obs, info = env.reset(seed=0) - wrapped_obs, wrapped_obs_info = wrapped_env.reset(seed=0) - assert np.allclose(wrapped_obs, affine_transform(obs)) - assert isinstance(wrapped_obs_info, dict) - - action = env.action_space.sample() - obs, reward, terminated, truncated, _ = env.step(action) - ( - wrapped_obs, - wrapped_reward, - wrapped_terminated, - wrapped_truncated, - _, - ) = wrapped_env.step(action) - assert np.allclose(wrapped_obs, affine_transform(obs)) - assert np.allclose(wrapped_reward, reward) - assert wrapped_terminated == terminated - assert wrapped_truncated == truncated diff --git a/tests/wrappers/test_transform_reward.py b/tests/wrappers/test_transform_reward.py deleted file mode 100644 index c2910f375..000000000 --- a/tests/wrappers/test_transform_reward.py +++ /dev/null @@ -1,63 +0,0 @@ -import numpy as np -import pytest - -import gymnasium as gym -from gymnasium.wrappers import TransformReward - - -@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) -def test_transform_reward(env_id): - # use case #1: scale - scales = [0.1, 200] - for scale in scales: - env = gym.make(env_id, disable_env_checker=True) - wrapped_env = TransformReward( - gym.make(env_id, disable_env_checker=True), lambda r: scale * r - ) - action = env.action_space.sample() - - env.reset(seed=0) - wrapped_env.reset(seed=0) - - _, reward, _, _, _ = env.step(action) - _, wrapped_reward, _, _, _ = wrapped_env.step(action) - - assert wrapped_reward == scale * reward - del env, wrapped_env - - # use case #2: clip - min_r = -0.0005 - max_r = 0.0002 - env = gym.make(env_id, disable_env_checker=True) - wrapped_env = TransformReward( - gym.make(env_id, disable_env_checker=True), - lambda r: np.clip(r, min_r, max_r), - ) - action = env.action_space.sample() - - env.reset(seed=0) - wrapped_env.reset(seed=0) - - _, reward, _, _, _ = env.step(action) - _, wrapped_reward, _, _, _ = wrapped_env.step(action) - - assert abs(wrapped_reward) < abs(reward) - assert wrapped_reward == -0.0005 or wrapped_reward == 0.0002 - del env, wrapped_env - - # use case #3: sign - env = gym.make(env_id, disable_env_checker=True) - wrapped_env = TransformReward( - gym.make(env_id, disable_env_checker=True), lambda r: np.sign(r) - ) - - env.reset(seed=0) - wrapped_env.reset(seed=0) - - for _ in range(1000): - action = env.action_space.sample() - _, wrapped_reward, terminated, truncated, _ = wrapped_env.step(action) - assert wrapped_reward in [-1.0, 0.0, 1.0] - if terminated or truncated: - break - del env, wrapped_env diff --git a/tests/wrappers/test_video_recorder.py b/tests/wrappers/test_video_recorder.py deleted file mode 100644 index 
e3e981ef6..000000000 --- a/tests/wrappers/test_video_recorder.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -import re - -import pytest - -import gymnasium as gym -from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder - - -class BrokenRecordableEnv(gym.Env): - metadata = {"render_modes": ["rgb_array_list"]} - - def __init__(self, render_mode="rgb_array_list"): - self.render_mode = render_mode - - def render(self): - pass - - -class UnrecordableEnv(gym.Env): - metadata = {"render_modes": [None]} - - def __init__(self, render_mode=None): - self.render_mode = render_mode - - def render(self): - pass - - -def test_record_simple(): - env = gym.make( - "CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True - ) - rec = VideoRecorder(env) - env.reset() - rec.capture_frame() - - rec.close() - - assert not rec.broken - assert os.path.exists(rec.path) - f = open(rec.path) - assert os.fstat(f.fileno()).st_size > 100 - - -def test_no_frames(): - env = BrokenRecordableEnv() - rec = VideoRecorder(env) - rec.close() - assert rec.functional - assert not os.path.exists(rec.path) - - -def test_record_unrecordable_method(): - error_message = ( - "Render mode is None, which is incompatible with RecordVideo." - " Initialize your environment with a render_mode that returns an" - " image, such as rgb_array." - ) - with pytest.raises(ValueError, match=re.escape(error_message)): - env = UnrecordableEnv() - rec = VideoRecorder(env) - assert not rec.enabled - rec.close() - - -def test_record_breaking_render_method(): - with pytest.warns( - UserWarning, - match=re.escape( - "Env returned None on `render()`. Disabling further rendering for video recorder by marking as disabled:" - ), - ): - env = BrokenRecordableEnv() - rec = VideoRecorder(env) - rec.capture_frame() - rec.close() - assert rec.broken - assert not os.path.exists(rec.path) - - -def test_text_envs(): - env = gym.make( - "FrozenLake-v1", render_mode="rgb_array_list", disable_env_checker=True - ) - video = VideoRecorder(env) - try: - env.reset() - video.capture_frame() - video.close() - finally: - os.remove(video.path) diff --git a/tests/wrappers/utils.py b/tests/wrappers/utils.py index 56b495765..a1b1b99bb 100644 --- a/tests/wrappers/utils.py +++ b/tests/wrappers/utils.py @@ -1,9 +1,91 @@ +"""Utility functions for testing the wrappers.""" from __future__ import annotations import gymnasium as gym +from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS +from tests.testing_env import GenericTestEnv + + +SEED = 42 +ENV_ID = "CartPole-v1" +DISCRETE_ACTION = 0 +NUM_ENVS = 3 +NUM_STEPS = 20 + + +def record_obs_reset(self: gym.Env, seed=None, options: dict = None): + """Records and uses an observation passed through options.""" + return options["obs"], {"obs": options["obs"]} + + +def record_random_obs_reset(self: gym.Env, seed=None, options=None): + """Records a random observation generated by the environment.""" + obs = self.observation_space.sample() + return obs, {"obs": obs} + + +def record_action_step(self: gym.Env, action): + """Records the action passed to the environment.""" + return 0, 0, False, False, {"action": action} + + +def record_random_obs_step(self: gym.Env, action): + """Records a random observation generated by the environment.""" + obs = self.observation_space.sample() + return obs, 0, False, False, {"obs": obs} + + +def record_action_as_obs_step(self: gym.Env, action): + """Uses the action as the observation.""" + return action, 0, False, False, {"obs": action} + + +def record_action_as_record_step(self:
gym.Env, action): + """Uses the action as the reward.""" + return 0, action, False, False, {"reward": action} + + +def check_obs( + env: gym.Env, + wrapped_env: gym.Wrapper, + transformed_obs, + original_obs, + strict: bool = True, +): + """Checks the original and transformed observations against the base and wrapped environments' observation spaces. + + Args: + env: The base environment. + wrapped_env: The wrapped environment. + transformed_obs: The transformed observation from the wrapped environment. + original_obs: The original observation from the base environment. + strict: Whether to also check that each observation is not contained in the other environment's observation space. + """ + assert ( + transformed_obs in wrapped_env.observation_space + ), f"{transformed_obs}, {wrapped_env.observation_space}" + assert ( + original_obs in env.observation_space + ), f"{original_obs}, {env.observation_space}" + + if strict: + assert ( + transformed_obs not in env.observation_space + ), f"{transformed_obs}, {env.observation_space}" + assert ( + original_obs not in wrapped_env.observation_space + ), f"{original_obs}, {wrapped_env.observation_space}" + + +TESTING_OBS_ENVS = [GenericTestEnv(observation_space=space) for space in TESTING_SPACES] +TESTING_OBS_ENVS_IDS = TESTING_SPACES_IDS + +TESTING_ACTION_ENVS = [GenericTestEnv(action_space=space) for space in TESTING_SPACES] +TESTING_ACTION_ENVS_IDS = TESTING_SPACES_IDS def has_wrapper(wrapped_env: gym.Env, wrapper_type: type[gym.Wrapper]) -> bool: + """Checks if the wrapper type is within the wrapped environment stack.""" while isinstance(wrapped_env, gym.Wrapper): if isinstance(wrapped_env, wrapper_type): return True diff --git a/tests/wrappers/vector/__init__.py b/tests/wrappers/vector/__init__.py new file mode 100644 index 000000000..ea142f581 --- /dev/null +++ b/tests/wrappers/vector/__init__.py @@ -0,0 +1 @@ +"""Test suite for `gymnasium.wrappers.vector`.""" diff --git a/tests/wrappers/test_vector_list_info.py b/tests/wrappers/vector/test_dict_info_to_list.py similarity index 74% rename from tests/wrappers/test_vector_list_info.py rename to tests/wrappers/vector/test_dict_info_to_list.py index 801f236af..cb09029a9 100644 --- a/tests/wrappers/test_vector_list_info.py +++ b/tests/wrappers/vector/test_dict_info_to_list.py @@ -1,8 +1,10 @@ +"""Test suite for DictInfoToList wrapper.""" + import numpy as np import pytest import gymnasium as gym -from gymnasium.wrappers import RecordEpisodeStatistics, VectorListInfo +from gymnasium.wrappers.vector import DictInfoToList, RecordEpisodeStatistics ENV_ID = "CartPole-v1" @@ -13,17 +15,17 @@ SEED = 42 def test_usage_in_vector_env(): env = gym.make(ENV_ID, disable_env_checker=True) - vector_env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True) + vector_env = gym.make_vec(ENV_ID, num_envs=NUM_ENVS, vectorization_mode="sync") - VectorListInfo(vector_env) + DictInfoToList(vector_env) with pytest.raises(AssertionError): - VectorListInfo(env) + DictInfoToList(env) def test_info_to_list(): - env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True) - wrapped_env = VectorListInfo(env_to_wrap) + env_to_wrap = gym.make_vec(ENV_ID, num_envs=NUM_ENVS, vectorization_mode="sync") + wrapped_env = DictInfoToList(env_to_wrap) wrapped_env.action_space.seed(SEED) _, info = wrapped_env.reset(seed=SEED) assert isinstance(info, list) @@ -40,8 +42,8 @@ def test_info_to_list(): def test_info_to_list_statistics(): - env_to_wrap = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True) - wrapped_env =
VectorListInfo(RecordEpisodeStatistics(env_to_wrap)) + env_to_wrap = gym.make_vec(ENV_ID, num_envs=NUM_ENVS, vectorization_mode="sync") + wrapped_env = DictInfoToList(RecordEpisodeStatistics(env_to_wrap)) _, info = wrapped_env.reset(seed=SEED) wrapped_env.action_space.seed(SEED) assert isinstance(info, list) diff --git a/tests/wrappers/vector/test_normalize_observation.py b/tests/wrappers/vector/test_normalize_observation.py new file mode 100644 index 000000000..25ab2617f --- /dev/null +++ b/tests/wrappers/vector/test_normalize_observation.py @@ -0,0 +1,68 @@ +"""Test suite for vector NormalizeObservation wrapper.""" +import numpy as np + +from gymnasium import spaces, wrappers +from gymnasium.vector import SyncVectorEnv +from gymnasium.vector.utils import create_empty_array +from tests.testing_env import GenericTestEnv + + +def thunk(): + return GenericTestEnv( + observation_space=spaces.Box( + low=np.array([0, -10, -5], dtype=np.float32), + high=np.array([10, -5, 10], dtype=np.float32), + ) + ) + + +def test_against_wrapper( + n_envs=3, + n_steps=250, + mean_rtol=np.array([0.1, 0.4, 0.25]), + var_rtol=np.array([0.15, 0.15, 0.18]), +): + vec_env = SyncVectorEnv([thunk for _ in range(n_envs)]) + vec_env = wrappers.vector.NormalizeObservation(vec_env) + + vec_env.reset() + for _ in range(n_steps): + vec_env.step(vec_env.action_space.sample()) + + env = wrappers.Autoreset(thunk()) + env = wrappers.NormalizeObservation(env) + env.reset() + for _ in range(n_envs * n_steps): + env.step(env.action_space.sample()) + + assert np.allclose(env.obs_rms.mean, vec_env.obs_rms.mean, rtol=mean_rtol) + assert np.allclose(env.obs_rms.var, vec_env.obs_rms.var, rtol=var_rtol) + + +def test_update_running_mean(): + env = SyncVectorEnv([thunk for _ in range(2)]) + env = wrappers.vector.NormalizeObservation(env) + + # Default value is True + assert env.update_running_mean + + obs, _ = env.reset() + for _ in range(100): + env.step(env.action_space.sample()) + + # Disable + env.update_running_mean = False + rms_mean = np.copy(env.obs_rms.mean) + rms_var = np.copy(env.obs_rms.var) + + val_step = 25 + obs_buffer = create_empty_array(env.observation_space, val_step) + env.action_space.seed(123) + for i in range(val_step): + obs, _, _, _, _ = env.step(env.action_space.sample()) + obs_buffer[i] = obs + + assert np.all(rms_mean == env.obs_rms.mean) + assert np.all(rms_var == env.obs_rms.var) + assert np.allclose(np.mean(obs_buffer, axis=(0, 1)), 0, atol=0.5) + assert np.allclose(np.var(obs_buffer, axis=(0, 1)), 1, atol=0.5) diff --git a/tests/wrappers/vector/test_normalize_reward.py b/tests/wrappers/vector/test_normalize_reward.py new file mode 100644 index 000000000..88009a3bb --- /dev/null +++ b/tests/wrappers/vector/test_normalize_reward.py @@ -0,0 +1,70 @@ +"""Test suite for vector NormalizeReward wrapper.""" +from typing import Optional + +import numpy as np + +from gymnasium import wrappers +from gymnasium.core import ActType +from gymnasium.vector import SyncVectorEnv +from tests.testing_env import GenericTestEnv + + +def reset_func(self, seed: Optional[int] = None, options: Optional[dict] = None): + self.step_id = 0 + return self.observation_space.sample(), {} + + +def step_func(self, action: ActType): + self.step_id += 1 + terminated = self.step_id == 10 + return self.observation_space.sample(), float(terminated), terminated, False, {} + + +def thunk(): + return GenericTestEnv(step_func=step_func, reset_func=reset_func) + + +def test_functionality( + n_envs=3, + n_steps=100, +): + env = SyncVectorEnv([thunk
for _ in range(n_envs)]) + env = wrappers.vector.NormalizeReward(env) + + env.reset() + for _ in range(n_steps): + action = env.action_space.sample() + env.step(action) + + env.reset() + forward_rets = [] + accumulated_rew = 0 + for _ in range(n_steps): + action = env.action_space.sample() + _, rew, ter, tru, _ = env.step(action) + dones = np.logical_or(ter, tru) + accumulated_rew = accumulated_rew * 0.9 * dones + rew + forward_rets.append(accumulated_rew) + + env.close() + + forward_rets = np.asarray(forward_rets) + assert np.allclose(np.std(forward_rets, axis=0), 1.33, atol=0.1) + + +def test_against_wrapper(n_envs=3, n_steps=100, rtol=0.01, atol=0): + vec_env = SyncVectorEnv([thunk for _ in range(n_envs)]) + vec_env = wrappers.vector.NormalizeReward(vec_env) + vec_env.reset() + for _ in range(n_steps): + action = vec_env.action_space.sample() + vec_env.step(action) + + env = wrappers.Autoreset(thunk()) + env = wrappers.NormalizeReward(env) + env.reset() + for _ in range(n_steps): + action = env.action_space.sample() + _, _, ter, tru, _ = env.step(action) + + assert np.allclose(env.return_rms.var, vec_env.return_rms.var, rtol=rtol, atol=atol) diff --git a/tests/experimental/wrappers/vector/test_vector_wrappers.py b/tests/wrappers/vector/test_vector_wrappers.py similarity index 63% rename from tests/experimental/wrappers/vector/test_vector_wrappers.py rename to tests/wrappers/vector/test_vector_wrappers.py index 2c308101d..ed6e85e97 100644 --- a/tests/experimental/wrappers/vector/test_vector_wrappers.py +++ b/tests/wrappers/vector/test_vector_wrappers.py @@ -1,6 +1,10 @@ """Tests that the vectorised wrappers operate identically in `VectorEnv(Wrapper)` and `VectorWrapper(VectorEnv)`. -The exception is the data converter wrappers (`JaxToTorch`, `JaxToNumpy` and `NumpyToJax`) +The exceptions are: + * Data conversion wrappers - `JaxToTorch`, `JaxToNumpy` and `NumpyToJax` + * Normalizing wrappers - `NormalizeObservation` and `NormalizeReward` + * Different implementations - `LambdaObservation`, `LambdaReward` and `LambdaAction` + * Different random sources - `StickyAction` """ from __future__ import annotations @@ -10,10 +14,10 @@ import numpy as np import pytest import gymnasium as gym -from gymnasium.experimental import wrappers -from gymnasium.experimental.vector import VectorEnv +from gymnasium import wrappers from gymnasium.spaces import Box, Dict, Discrete from gymnasium.utils.env_checker import data_equivalence +from gymnasium.vector import VectorEnv from tests.testing_env import GenericTestEnv @@ -35,27 +39,24 @@ def custom_environments(): @pytest.mark.parametrize( "env_id, wrapper_name, kwargs", ( - ("CustomDictEnv-v0", "FilterObservationV0", {"filter_keys": ["a"]}), - ("CartPole-v1", "FlattenObservationV0", {}), - ("CarRacing-v2", "GrayscaleObservationV0", {}), - # ("CarRacing-v2", "ResizeObservationV0", {"shape": (35, 45)}), - ("CarRacing-v2", "ReshapeObservationV0", {"shape": (96, 48, 6)}), - ("CartPole-v1", "RescaleObservationV0", {"min_obs": 0, "max_obs": 1}), - ("CartPole-v1", "DtypeObservationV0", {"dtype": np.int32}), - # ("CartPole-v1", "PixelObservationV0", {}), - # ("CartPole-v1", "NormalizeObservationV0", {}), - # ("CartPole-v1", "TimeAwareObservationV0", {}), - # ("CartPole-v1", "FrameStackObservationV0", {}), - # ("CartPole-v1", "DelayObservationV0", {}), - ("MountainCarContinuous-v0", "ClipActionV0", {}), + ("CustomDictEnv-v0", "FilterObservation", {"filter_keys": ["a"]}), + ("CartPole-v1", "FlattenObservation", {}), + ("CarRacing-v2",
"GrayscaleObservation", {}), + # ("CarRacing-v2", "ResizeObservation", {"shape": (35, 45)}), + ("CarRacing-v2", "ReshapeObservation", {"shape": (96, 48, 6)}), + ("CartPole-v1", "RescaleObservation", {"min_obs": 0, "max_obs": 1}), + ("CartPole-v1", "DtypeObservation", {"dtype": np.int32}), + # ("CartPole-v1", "RenderObservation", {}), + # ("CartPole-v1", "TimeAwareObservation", {}), + # ("CartPole-v1", "FrameStackObservation", {}), + # ("CartPole-v1", "DelayObservation", {}), + ("MountainCarContinuous-v0", "ClipAction", {}), ( "MountainCarContinuous-v0", - "RescaleActionV0", + "RescaleAction", {"min_action": 1, "max_action": 2}, ), - # ("CartPole-v1", "StickyActionV0", {}), - ("CartPole-v1", "ClipRewardV0", {"min_reward": 0.25, "max_reward": 0.75}), - # ("CartPole-v1", "NormalizeRewardV1", {}), + ("CartPole-v1", "ClipReward", {"min_reward": 0.25, "max_reward": 0.75}), ), ) def test_vector_wrapper_equivalence( @@ -112,12 +113,3 @@ def test_vector_wrapper_equivalence( wrapper_vector_env.close() vector_wrapper_env.close() - - -# ("CartPole-v1", "LambdaObservationV0", {"func": lambda obs: obs + 1}), -# ("CartPole-v1", "LambdaActionV0", {"func": lambda action: action + 1}), -# ("CartPole-v1", "LambdaRewardV0", {"func": lambda reward: reward + 1}), -# (vector.JaxToNumpyV0, {}, {}), -# (vector.JaxToTorchV0, {}, {}), -# (vector.NumpyToTorchV0, {}, {}), -# ("CartPole-v1", "RecordEpisodeStatisticsV0", {}), # for the time taken in info, this is not equivalent for two instances