diff --git a/gymnasium/core.py b/gymnasium/core.py index 96207e5e4..8225e4905 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -40,7 +40,7 @@ class Env(Generic[ObsType, ActType]): - :attr:`action_space` - The Space object corresponding to valid actions, all valid actions should be contained within the space. - :attr:`observation_space` - The Space object corresponding to valid observations, all valid observations should be contained within the space. - :attr:`spec` - An environment spec that contains the information used to initialize the environment from :meth:`gymnasium.make` - - :attr:`metadata` - The metadata of the environment, e.g., `{"render_modes": ["rgb_array", "human"], "render_fps": 30}`. For Jax or Torch, this can be indicated to users with `"jax"=True` or `"torch"=True`. + - :attr:`metadata` - The metadata of the environment, e.g. `{"render_modes": ["rgb_array", "human"], "render_fps": 30}`. For Jax or Torch, this can be indicated to users with `"jax"=True` or `"torch"=True`. - :attr:`np_random` - The random number generator for the environment. This is automatically assigned during ``super().reset(seed=seed)`` and when assessing :attr:`np_random`. @@ -50,7 +50,7 @@ class Env(Generic[ObsType, ActType]): To get reproducible sampling of actions, a seed can be set with ``env.action_space.seed(123)``. Note: - For strict type checking (e.g., mypy or pyright), :class:`Env` is a generic class with two parameterized types: ``ObsType`` and ``ActType``. + For strict type checking (e.g. mypy or pyright), :class:`Env` is a generic class with two parameterized types: ``ObsType`` and ``ActType``. The ``ObsType`` and ``ActType`` are the expected types of the observations and actions used in :meth:`reset` and :meth:`step`. The environment's :attr:`observation_space` and :attr:`action_space` should have type ``Space[ObsType]`` and ``Space[ActType]``, see a space's implementation to find its parameterized type. diff --git a/gymnasium/vector/utils/space_utils.py b/gymnasium/vector/utils/space_utils.py index c4d5ef68d..05c5be55d 100644 --- a/gymnasium/vector/utils/space_utils.py +++ b/gymnasium/vector/utils/space_utils.py @@ -9,6 +9,7 @@ from __future__ import annotations +import typing from copy import deepcopy from functools import singledispatch from typing import Any, Iterable, Iterator @@ -44,17 +45,17 @@ __all__ = [ @singledispatch def batch_space(space: Space[Any], n: int = 1) -> Space[Any]: - """Create a (batched) space, containing multiple copies of a single space. + """Batch spaces of size `n` optimized for neural networks. Args: - space: Space (e.g. the observation space) for a single environment in the vectorized environment. - n: Number of environments in the vectorized environment. + space: Space (e.g. the observation space for a single environment in the vectorized environment). + n: Number of spaces to batch by (e.g. the number of environments in a vectorized environment). Returns: - Space (e.g. the observation space) for a batch of environments in the vectorized environment. + Batched space of size `n`. Raises: - ValueError: Cannot batch space does not have a registered function. + ValueError: Cannot batch spaces that does not have a registered function. Example: @@ -147,8 +148,21 @@ def _batch_space_custom(space: Graph | Text | Sequence | OneOf, n: int = 1): @singledispatch -def batch_differing_spaces(spaces: list[Space]): - """Batch a Sequence of spaces that allows the subspaces to contain minor differences.""" +def batch_differing_spaces(spaces: typing.Sequence[Space]) -> Space: + """Batch a Sequence of spaces where subspaces to contain minor differences. + + Args: + spaces: A sequence of Spaces with minor differences (the same space type but different parameters). + + Returns: + A batched space + + Example: + >>> from gymnasium.spaces import Discrete + >>> spaces = [Discrete(3), Discrete(5), Discrete(4), Discrete(8)] + >>> batch_differing_spaces(spaces) + MultiDiscrete([3 5 4 8]) + """ assert len(spaces) > 0, "Expects a non-empty list of spaces" assert all( isinstance(space, type(spaces[0])) for space in spaces @@ -257,19 +271,12 @@ def _batch_spaces_undefined(spaces: list[Graph | Text | Sequence | OneOf]): @singledispatch -def iterate(space: Space[T_cov], items: Iterable[T_cov]) -> Iterator: +def iterate(space: Space[T_cov], items: T_cov) -> Iterator: """Iterate over the elements of a (batched) space. Args: - space: Observation space of a single environment in the vectorized environment. - items: Samples to be concatenated. - out: The output object. This object is a (possibly nested) numpy array. - - Returns: - The output object. This object is a (possibly nested) numpy array. - - Raises: - ValueError: Space is not an instance of :class:`gymnasium.Space` + space: (batched) space (e.g. `action_space` or `observation_space` from vectorized environment). + items: Batched samples to be iterated over (e.g. sample from the space). Example: >>> from gymnasium.spaces import Box, Dict @@ -353,15 +360,15 @@ def concatenate( """Concatenate multiple samples from space into a single object. Args: - space: Observation space of a single environment in the vectorized environment. - items: Samples to be concatenated. - out: The output object. This object is a (possibly nested) numpy array. + space: Space of each item (e.g. `single_action_space` from vectorized environment) + items: Samples to be concatenated (e.g. all sample should be an element of the `space`). + out: The output object (e.g. generated from `create_empty_array`) Returns: - The output object. This object is a (possibly nested) numpy array. + The output object, can be the same object `out`. Raises: - ValueError: Space + ValueError: Space is not a valid :class:`gymnasium.Space` instance Example: >>> from gymnasium.spaces import Box @@ -423,7 +430,7 @@ def _concatenate_custom(space: Space, items: Iterable, out: None) -> tuple[Any, def create_empty_array( space: Space, n: int = 1, fn: callable = np.zeros ) -> tuple[Any, ...] | dict[str, Any] | np.ndarray: - """Create an empty (possibly nested) (normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``. + """Create an empty (possibly nested and normally numpy-based) array, used in conjunction with ``concatenate(..., out=array)``. In most cases, the array will be contained within the batched space, however, this is not guaranteed. diff --git a/gymnasium/wrappers/vector/vectorize_action.py b/gymnasium/wrappers/vector/vectorize_action.py index 3dc4a797a..f0f0b8e57 100644 --- a/gymnasium/wrappers/vector/vectorize_action.py +++ b/gymnasium/wrappers/vector/vectorize_action.py @@ -138,7 +138,7 @@ class VectorizeTransformAction(VectorActionWrapper): self.action_space = batch_space(self.single_action_space, self.num_envs) self.same_out = self.action_space == self.env.action_space - self.out = create_empty_array(self.single_action_space, self.num_envs) + self.out = create_empty_array(self.env.single_action_space, self.num_envs) def actions(self, actions: ActType) -> ActType: """Applies the wrapper to each of the action. @@ -151,7 +151,7 @@ class VectorizeTransformAction(VectorActionWrapper): """ if self.same_out: return concatenate( - self.single_action_space, + self.env.single_action_space, tuple( self.wrapper.func(action) for action in iterate(self.action_space, actions) @@ -161,10 +161,10 @@ class VectorizeTransformAction(VectorActionWrapper): else: return deepcopy( concatenate( - self.single_action_space, + self.env.single_action_space, tuple( self.wrapper.func(action) - for action in iterate(self.env.action_space, actions) + for action in iterate(self.action_space, actions) ), self.out, ) diff --git a/tests/functional/test_func_jax_env.py b/tests/functional/test_func_jax_env.py index d6d6ca435..e3eb225ec 100644 --- a/tests/functional/test_func_jax_env.py +++ b/tests/functional/test_func_jax_env.py @@ -4,6 +4,10 @@ import numpy as np import pytest +pytest.skip( + "Github CI is running forever for the tests in this file.", allow_module_level=True +) + jax = pytest.importorskip("jax") import jax.numpy as jnp # noqa: E402 import jax.random as jrng # noqa: E402 diff --git a/tests/wrappers/vector/test_vector_wrappers.py b/tests/wrappers/vector/test_vector_wrappers.py index ff7c9ce0a..311ecc6fe 100644 --- a/tests/wrappers/vector/test_vector_wrappers.py +++ b/tests/wrappers/vector/test_vector_wrappers.py @@ -25,7 +25,7 @@ from tests.testing_env import GenericTestEnv @pytest.fixture def custom_environments(): gym.register( - "CustomDictEnv-v0", + "DictObsEnv-v0", lambda: GenericTestEnv( observation_space=Dict({"a": Box(0, 1), "b": Discrete(5)}) ), @@ -33,14 +33,14 @@ def custom_environments(): yield - del gym.registry["CustomDictEnv-v0"] + del gym.registry["DictObsEnv-v0"] @pytest.mark.parametrize("num_envs", (1, 3)) @pytest.mark.parametrize( "env_id, wrapper_name, kwargs", ( - ("CustomDictEnv-v0", "FilterObservation", {"filter_keys": ["a"]}), + ("DictObsEnv-v0", "FilterObservation", {"filter_keys": ["a"]}), ("CartPole-v1", "FlattenObservation", {}), ("CarRacing-v3", "GrayscaleObservation", {}), ("CarRacing-v3", "ResizeObservation", {"shape": (35, 45)}), diff --git a/tests/wrappers/vector/test_vectorize_transform.py b/tests/wrappers/vector/test_vectorize_transform.py new file mode 100644 index 000000000..ec20e0fa0 --- /dev/null +++ b/tests/wrappers/vector/test_vectorize_transform.py @@ -0,0 +1,52 @@ +from functools import partial + +import numpy as np + +import gymnasium as gym +from gymnasium.vector import SyncVectorEnv +from tests.testing_env import GenericTestEnv + + +def test_vectorize_box_to_dict_action(): + def func(x): + return x["key"] + + envs = SyncVectorEnv([lambda: GenericTestEnv() for _ in range(2)]) + envs = gym.wrappers.vector.VectorizeTransformAction( + env=envs, + wrapper=gym.wrappers.TransformAction, + func=func, + action_space=gym.spaces.Dict( + {"key": gym.spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)} + ), + ) + + obs, _ = envs.reset() + obs, _, _, _, _ = envs.step(envs.action_space.sample()) + envs.close() + + +def test_vectorize_dict_to_box_obs(): + wrappers = [ + partial( + gym.wrappers.TransformObservation, + func=lambda x: {"key1": x[0:1], "key2": x[1:]}, + observation_space=gym.spaces.Dict( + { + "key1": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(1,)), + "key2": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(3,)), + } + ), + ) + ] + envs = gym.make_vec( + "CartPole-v1", + num_envs=2, + vectorization_mode=gym.VectorizeMode.ASYNC, + wrappers=wrappers, + ) + obs, _ = envs.reset() + assert obs in envs.observation_space + obs, _, _, _, _ = envs.step(envs.action_space.sample()) + assert obs in envs.observation_space + envs.close()