Copy vector to experimental.vector (#317)

This commit is contained in:
Mark Towers
2023-02-18 21:08:08 +00:00
committed by GitHub
parent be58fffb46
commit a55b42dd1e
20 changed files with 1534 additions and 102 deletions

View File

@@ -51,7 +51,7 @@ repos:
rev: 6.1.1
hooks:
- id: pydocstyle
exclude: ^(gymnasium/envs/box2d)|(gymnasium/envs/classic_control)|(gymnasium/envs/mujoco)|(gymnasium/envs/toy_text)|(tests/experimental/vector)|(tests/envs)|(tests/spaces)|(tests/utils)|(tests/vector)|(tests/wrappers)|(docs/)
exclude: ^(gymnasium/envs/box2d)|(gymnasium/envs/classic_control)|(gymnasium/envs/mujoco)|(gymnasium/envs/toy_text)|(tests/envs)|(tests/spaces)|(tests/utils)|(tests/vector)|(tests/wrappers)|(docs/)
args:
- --source
- --explain

View File

@@ -1,13 +1,23 @@
"""Experimental vector env API."""
from gymnasium.experimental.vector import utils
from gymnasium.experimental.vector.async_vector_env import AsyncVectorEnv
from gymnasium.experimental.vector.sync_vector_env import SyncVectorEnv
from gymnasium.experimental.vector.vector_env import VectorEnv, VectorWrapper
from gymnasium.experimental.vector.vector_env import (
VectorActionWrapper,
VectorEnv,
VectorObservationWrapper,
VectorRewardWrapper,
VectorWrapper,
)
__all__ = [
# Vector
"VectorEnv",
"VectorWrapper",
"VectorObservationWrapper",
"VectorActionWrapper",
"VectorRewardWrapper",
"SyncVectorEnv",
"AsyncVectorEnv",
"utils",
]
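Taken together, the expanded `__all__` above exposes the three new wrapper base classes next to the existing vector types. A minimal sketch of the resulting import surface (names copied from the `__all__` list; nothing below is part of the diff itself):

# Import surface of gymnasium.experimental.vector after this commit
# (names taken from the __all__ list above).
from gymnasium.experimental.vector import (
    AsyncVectorEnv,
    SyncVectorEnv,
    VectorActionWrapper,
    VectorEnv,
    VectorObservationWrapper,
    VectorRewardWrapper,
    VectorWrapper,
    utils,
)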

View File

@@ -1,25 +1,26 @@
"""An async vector environment."""
from __future__ import annotations
import multiprocessing as mp
import sys
import time
from copy import deepcopy
from enum import Enum
from multiprocessing import Queue
from multiprocessing.connection import Connection
from multiprocessing.queues import Queue
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Sequence
import numpy as np
from gymnasium import logger
from gymnasium.core import ObsType
from gymnasium.core import Env, ObsType
from gymnasium.error import (
AlreadyPendingCallError,
ClosedEnvironmentError,
CustomSpaceError,
NoAsyncCallError,
)
from gymnasium.experimental.vector.vector_env import VectorEnv
from gymnasium.vector.utils import (
from gymnasium.experimental.vector.utils import (
CloudpickleWrapper,
batch_space,
clear_mpi_env_vars,
@@ -30,6 +31,7 @@ from gymnasium.vector.utils import (
read_from_shared_memory,
write_to_shared_memory,
)
from gymnasium.experimental.vector.vector_env import VectorEnv
__all__ = ["AsyncVectorEnv"]
@@ -47,26 +49,25 @@ class AsyncVectorEnv(VectorEnv):
It uses ``multiprocessing`` processes, and pipes for communication.
Example::
Example:
>>> import gymnasium as gym
>>> env = gym.vector.AsyncVectorEnv([
... lambda: gym.make("Pendulum-v1", g=9.81),
... lambda: gym.make("Pendulum-v1", g=1.62)
... ])
>>> env.reset() # doctest: +SKIP
array([[-0.8286432 , 0.5597771 , 0.90249056],
[-0.85009176, 0.5266346 , 0.60007906]], dtype=float32)
>>> env.reset(seed=42)
(array([[-0.14995256, 0.9886932 , -0.12224312],
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {})
"""
def __init__(
self,
env_fns: Sequence[callable],
env_fns: Sequence[Callable[[], Env]],
shared_memory: bool = True,
copy: bool = True,
context: Optional[str] = None,
context: str | None = None,
daemon: bool = True,
worker: Optional[callable] = None,
worker: callable | None = None,
):
"""Vectorized environment that runs multiple environments in parallel.
@@ -124,7 +125,7 @@ class AsyncVectorEnv(VectorEnv):
self.observations = read_from_shared_memory(
self.single_observation_space, _obs_buffer, n=self.num_envs
)
except CustomSpaceError:
except CustomSpaceError as e:
raise ValueError(
"Using `shared_memory=True` in `AsyncVectorEnv` "
"is incompatible with non-standard Gymnasium observation spaces "
@@ -132,7 +133,7 @@ class AsyncVectorEnv(VectorEnv):
"only compatible with default Gymnasium spaces (e.g. `Box`, "
"`Tuple`, `Dict`) for batching. Set `shared_memory=False` "
"if you use custom observation spaces."
)
) from e
else:
_obs_buffer = None
self.observations = create_empty_array(
@@ -170,8 +171,8 @@ class AsyncVectorEnv(VectorEnv):
def reset_async(
self,
seed: Optional[Union[int, List[int]]] = None,
options: Optional[dict] = None,
seed: int | list[int] | None = None,
options: dict | None = None,
):
"""Send calls to the :obj:`reset` methods of the sub-environments.
@@ -213,8 +214,8 @@ class AsyncVectorEnv(VectorEnv):
def reset_wait(
self,
timeout: Optional[Union[int, float]] = None,
) -> Union[ObsType, Tuple[ObsType, List[dict]]]:
timeout: int | float | None = None,
) -> tuple[ObsType, list[dict]]:
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
Args:
@@ -260,8 +261,8 @@ class AsyncVectorEnv(VectorEnv):
def reset(
self,
*,
seed: Optional[Union[int, List[int]]] = None,
options: Optional[dict] = None,
seed: int | list[int] | None = None,
options: dict | None = None,
):
"""Reset all parallel environments and return a batch of initial observations and info.
@@ -301,8 +302,8 @@ class AsyncVectorEnv(VectorEnv):
self._state = AsyncState.WAITING_STEP
def step_wait(
self, timeout: Optional[Union[int, float]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
self, timeout: int | float | None = None
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
"""Wait for the calls to :obj:`step` in each sub-environment to finish.
Args:
@@ -336,6 +337,7 @@ class AsyncVectorEnv(VectorEnv):
obs, rew, terminated, truncated, info = result
successes.append(success)
if success:
observations_list.append(obs)
rewards.append(rew)
terminateds.append(terminated)
@@ -396,7 +398,7 @@ class AsyncVectorEnv(VectorEnv):
pipe.send(("_call", (name, args, kwargs)))
self._state = AsyncState.WAITING_CALL
def call_wait(self, timeout: Optional[Union[int, float]] = None) -> list:
def call_wait(self, timeout: int | float | None = None) -> list:
"""Calls all parent pipes and waits for the results.
Args:
@@ -429,7 +431,7 @@ class AsyncVectorEnv(VectorEnv):
return results
def call(self, name: str, *args, **kwargs) -> List[Any]:
def call(self, name: str, *args, **kwargs) -> list[Any]:
"""Call a method, or get a property, from each parallel environment.
Args:
@@ -454,7 +456,7 @@ class AsyncVectorEnv(VectorEnv):
"""
return self.call(name)
def set_attr(self, name: str, values: Union[list, tuple, object]):
def set_attr(self, name: str, values: list[Any] | tuple[Any] | object):
"""Sets an attribute of the sub-environments.
Args:
@@ -489,9 +491,7 @@ class AsyncVectorEnv(VectorEnv):
_, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
def close_extras(
self, timeout: Optional[Union[int, float]] = None, terminate: bool = False
):
def close_extras(self, timeout: int | float | None = None, terminate: bool = False):
"""Close the environments & clean up the extra resources (processes and pipes).
Args:
@@ -556,15 +556,13 @@ class AsyncVectorEnv(VectorEnv):
same_observation_spaces, same_action_spaces = zip(*results)
if not all(same_observation_spaces):
raise RuntimeError(
"Some environments have an observation space different from "
f"`{self.single_observation_space}`. In order to batch observations, "
"the observation spaces from all environments must be equal."
f"Some environments have an observation space different from `{self.single_observation_space}`. "
"In order to batch observations, the observation spaces from all environments must be equal."
)
if not all(same_action_spaces):
raise RuntimeError(
"Some environments have an action space different from "
f"`{self.single_action_space}`. In order to batch actions, the "
"action spaces from all environments must be equal."
f"Some environments have an action space different from `{self.single_action_space}`. "
"In order to batch actions, the action spaces from all environments must be equal."
)
def _assert_is_running(self):
@@ -573,7 +571,7 @@ class AsyncVectorEnv(VectorEnv):
f"Trying to operate on `{type(self).__name__}`, after a call to `close()`."
)
def _raise_if_errors(self, successes):
def _raise_if_errors(self, successes: list[bool]):
if all(successes):
return
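For orientation, a minimal sketch of the `call`/`set_attr` surface whose annotations are modernised above (the environment id and the `g` attribute are illustrative choices, not taken from this diff):

# Sketch only: exercises call/set_attr/close on the experimental AsyncVectorEnv.
# The __main__ guard keeps spawn-based multiprocessing start methods happy.
import gymnasium as gym
from gymnasium.experimental.vector import AsyncVectorEnv

if __name__ == "__main__":
    envs = AsyncVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(2)])
    envs.reset(seed=42)
    print(envs.call("g"))             # read the `g` attribute from every sub-environment
    envs.set_attr("g", [9.81, 1.62])  # one value per sub-environment
    envs.close()                      # uses close_extras(timeout=None, terminate=False)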

View File

@@ -1,13 +1,19 @@
"""A synchronous vector environment."""
from __future__ import annotations
from copy import deepcopy
from typing import Any, Callable, Iterator, List, Optional, Union
from typing import Any, Callable, Iterator
import numpy as np
from gymnasium import Env
from gymnasium.experimental.vector.utils import (
batch_space,
concatenate,
create_empty_array,
iterate,
)
from gymnasium.experimental.vector.vector_env import VectorEnv
from gymnasium.vector.utils import concatenate, create_empty_array, iterate
from gymnasium.vector.utils.spaces import batch_space
__all__ = ["SyncVectorEnv"]
@@ -16,16 +22,15 @@ __all__ = ["SyncVectorEnv"]
class SyncVectorEnv(VectorEnv):
"""Vectorized environment that serially runs multiple environments.
Example::
Example:
>>> import gymnasium as gym
>>> env = gym.vector.SyncVectorEnv([
... lambda: gym.make("Pendulum-v1", g=9.81),
... lambda: gym.make("Pendulum-v1", g=1.62)
... ])
>>> env.reset() # doctest: +SKIP
array([[-0.8286432 , 0.5597771 , 0.90249056],
[-0.85009176, 0.5266346 , 0.60007906]], dtype=float32)
>>> env.reset(seed=42)
(array([[-0.14995256, 0.9886932 , -0.12224312],
[ 0.5760367 , 0.8174238 , -0.91244936]], dtype=float32), {})
"""
def __init__(
@@ -70,8 +75,8 @@ class SyncVectorEnv(VectorEnv):
def reset(
self,
seed: Optional[Union[int, List[int]]] = None,
options: Optional[dict] = None,
seed: int | list[int] | None = None,
options: dict | None = None,
):
"""Waits for the calls triggered by :meth:`reset_async` to finish and returns the results.
@@ -179,7 +184,7 @@ class SyncVectorEnv(VectorEnv):
"""
return self.call(name)
def set_attr(self, name: str, values: Union[list, tuple, Any]):
def set_attr(self, name: str, values: list | tuple | Any):
"""Sets an attribute of the sub-environments.
Args:

View File

@@ -0,0 +1,30 @@
"""Module for gymnasium experimental vector utility functions."""
from gymnasium.experimental.vector.utils.misc import (
CloudpickleWrapper,
clear_mpi_env_vars,
)
from gymnasium.experimental.vector.utils.shared_memory import (
create_shared_memory,
read_from_shared_memory,
write_to_shared_memory,
)
from gymnasium.experimental.vector.utils.space_utils import (
batch_space,
concatenate,
create_empty_array,
iterate,
)
__all__ = [
"batch_space",
"iterate",
"concatenate",
"create_empty_array",
"read_from_shared_memory",
"create_shared_memory",
"write_to_shared_memory",
"CloudpickleWrapper",
"clear_mpi_env_vars",
]

View File

@@ -0,0 +1,61 @@
"""Miscellaneous utilities."""
from __future__ import annotations
import contextlib
import os
from collections.abc import Callable
from gymnasium.core import Env
__all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"]
class CloudpickleWrapper:
"""Wrapper that uses cloudpickle to pickle and unpickle the result."""
def __init__(self, fn: Callable[[], Env]):
"""Cloudpickle wrapper for a function."""
self.fn = fn
def __getstate__(self):
"""Get the state using `cloudpickle.dumps(self.fn)`."""
import cloudpickle
return cloudpickle.dumps(self.fn)
def __setstate__(self, ob):
"""Sets the state with obs."""
import pickle
self.fn = pickle.loads(ob)
def __call__(self):
"""Calls the function `self.fn` with no arguments."""
return self.fn()
@contextlib.contextmanager
def clear_mpi_env_vars():
"""Clears the MPI of environment variables.
`from mpi4py import MPI` will call `MPI_Init` by default.
If the child process has MPI environment variables, MPI will think that the child process
is an MPI process just like the parent and do bad things such as hang.
This context manager is a hacky way to clear those environment variables
temporarily such as when we are starting multiprocessing Processes.
Yields:
Yields for the context manager
"""
removed_environment = {}
for k, v in list(os.environ.items()):
for prefix in ["OMPI_", "PMI_"]:
if k.startswith(prefix):
removed_environment[k] = v
del os.environ[k]
try:
yield
finally:
os.environ.update(removed_environment)
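A minimal sketch of how the two helpers above are intended to be combined when spawning worker processes, mirroring what `AsyncVectorEnv` does internally (the `_worker` function is a hypothetical stand-in):

# Sketch: ship a lambda env factory across a Process boundary with
# CloudpickleWrapper, and strip MPI variables while the worker starts.
import multiprocessing as mp

import gymnasium as gym
from gymnasium.experimental.vector.utils import CloudpickleWrapper, clear_mpi_env_vars


def _worker(env_fn: CloudpickleWrapper):
    env = env_fn()  # calling the wrapper calls the wrapped factory
    env.reset(seed=0)
    env.close()


if __name__ == "__main__":
    env_fn = CloudpickleWrapper(lambda: gym.make("CartPole-v1"))
    with clear_mpi_env_vars():  # OMPI_*/PMI_* vars are removed only inside this block
        process = mp.Process(target=_worker, args=(env_fn,))
        process.start()
    process.join()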

View File

@@ -0,0 +1,181 @@
"""Utility functions for vector environments to share memory between processes."""
from __future__ import annotations
import multiprocessing as mp
from collections import OrderedDict
from ctypes import c_bool
from functools import singledispatch
import numpy as np
from gymnasium.error import CustomSpaceError
from gymnasium.spaces import (
Box,
Dict,
Discrete,
MultiBinary,
MultiDiscrete,
Space,
Tuple,
)
__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]
@singledispatch
def create_shared_memory(space: Space, n: int = 1, ctx=mp) -> dict | tuple | mp.Array:
"""Create a shared memory object, to be shared across processes.
This eventually contains the observations from the vectorized environment.
Args:
space: Observation space of a single environment in the vectorized environment.
n: Number of environments in the vectorized environment (i.e. the number of processes).
ctx: The multiprocessing module
Returns:
The shared memory object to be shared across processes.
Raises:
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
"""
raise CustomSpaceError(
f"Cannot create a shared memory for space with type `{type(space)}`. "
"`create_shared_memory` only supports by default built-in Gymnasium spaces, register function for custom space."
)
@create_shared_memory.register(Box)
@create_shared_memory.register(Discrete)
@create_shared_memory.register(MultiDiscrete)
@create_shared_memory.register(MultiBinary)
def _create_base_shared_memory(space, n: int = 1, ctx=mp):
dtype = space.dtype.char
if dtype in "?":
dtype = c_bool
return ctx.Array(dtype, n * int(np.prod(space.shape)))
@create_shared_memory.register(Tuple)
def _create_tuple_shared_memory(space, n: int = 1, ctx=mp):
return tuple(
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
)
@create_shared_memory.register(Dict)
def _create_dict_shared_memory(space, n=1, ctx=mp):
return OrderedDict(
[
(key, create_shared_memory(subspace, n=n, ctx=ctx))
for (key, subspace) in space.spaces.items()
]
)
@singledispatch
def read_from_shared_memory(
space: Space, shared_memory: dict | tuple | mp.Array, n: int = 1
) -> dict | tuple | np.ndarray:
"""Read the batch of observations from shared memory as a numpy array.
.. note::
The numpy array objects returned by `read_from_shared_memory` share the
memory of `shared_memory`. Any changes to `shared_memory` are forwarded
to the returned observations, and vice-versa. To avoid any side-effect, use `np.copy`.
Args:
space: Observation space of a single environment in the vectorized environment.
shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
This object is created with `create_shared_memory`.
n: Number of environments in the vectorized environment (i.e. the number of processes).
Returns:
Batch of observations as a (possibly nested) numpy array.
Raises:
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
"""
raise CustomSpaceError(
f"Cannot read from a shared memory for space with type `{type(space)}`. "
"`read_from_shared_memory` only supports by default built-in Gymnasium spaces, register function for custom space."
)
@read_from_shared_memory.register(Box)
@read_from_shared_memory.register(Discrete)
@read_from_shared_memory.register(MultiDiscrete)
@read_from_shared_memory.register(MultiBinary)
def _read_base_from_shared_memory(space, shared_memory, n: int = 1):
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
(n,) + space.shape
)
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(space, shared_memory, n: int = 1):
return tuple(
read_from_shared_memory(subspace, memory, n=n)
for (memory, subspace) in zip(shared_memory, space.spaces)
)
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space, shared_memory, n: int = 1):
return OrderedDict(
[
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
for (key, subspace) in space.spaces.items()
]
)
@singledispatch
def write_to_shared_memory(
space: Space,
index: int,
value: np.ndarray,
shared_memory: dict | tuple | mp.Array,
):
"""Write the observation of a single environment into shared memory.
Args:
space: Observation space of a single environment in the vectorized environment.
index: Index of the environment (must be in `[0, num_envs)`).
value: Observation of the single environment to write to shared memory.
shared_memory: Shared object across processes. This contains the observations from the vectorized environment.
This object is created with `create_shared_memory`.
Raises:
CustomSpaceError: Space is not a valid :class:`gymnasium.Space` instance
"""
raise CustomSpaceError(
f"Cannot write to a shared memory for space with type `{type(space)}`. "
"`write_to_shared_memory` only supports by default built-in Gymnasium spaces, register function for custom space."
)
@write_to_shared_memory.register(Box)
@write_to_shared_memory.register(Discrete)
@write_to_shared_memory.register(MultiDiscrete)
@write_to_shared_memory.register(MultiBinary)
def _write_base_to_shared_memory(space, index, value, shared_memory):
size = int(np.prod(space.shape))
destination = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
np.copyto(
destination[index * size : (index + 1) * size],
np.asarray(value, dtype=space.dtype).flatten(),
)
@write_to_shared_memory.register(Tuple)
def _write_tuple_to_shared_memory(space, index, values, shared_memory):
for value, memory, subspace in zip(values, shared_memory, space.spaces):
write_to_shared_memory(subspace, index, value, memory)
@write_to_shared_memory.register(Dict)
def _write_dict_to_shared_memory(space, index, values, shared_memory):
for key, subspace in space.spaces.items():
write_to_shared_memory(subspace, index, values[key], shared_memory[key])
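As a reference point, a single-process sketch of the create/write/read round trip implemented above (the space and sample values are illustrative):

# Sketch: the three shared-memory helpers used together in one process.
import numpy as np

from gymnasium.spaces import Box
from gymnasium.experimental.vector.utils import (
    create_shared_memory,
    read_from_shared_memory,
    write_to_shared_memory,
)

space = Box(low=0.0, high=1.0, shape=(3,), dtype=np.float32)
shm = create_shared_memory(space, n=2)           # one flat ctypes buffer for 2 envs
view = read_from_shared_memory(space, shm, n=2)  # (2, 3) numpy view backed by shm

write_to_shared_memory(space, 0, space.sample(), shm)
write_to_shared_memory(space, 1, space.sample(), shm)
assert view.shape == (2, 3)                      # the view reflects both writes in place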

View File

@@ -0,0 +1,347 @@
"""Utility functions for gymnasium spaces: `batch_space` and `iterator`."""
from __future__ import annotations
from collections import OrderedDict
from copy import deepcopy
from functools import singledispatch
from typing import Callable, Iterable, Iterator
import numpy as np
from gymnasium.error import CustomSpaceError
from gymnasium.spaces import (
Box,
Dict,
Discrete,
MultiBinary,
MultiDiscrete,
Space,
Tuple,
)
__all__ = ["batch_space", "iterate", "concatenate", "create_empty_array"]
@singledispatch
def batch_space(space: Space, n: int = 1) -> Space:
"""Create a (batched) space, containing multiple copies of a single space.
Args:
space: Space (e.g. the observation space) for a single environment in the vectorized environment.
n: Number of environments in the vectorized environment.
Returns:
Space (e.g. the observation space) for a batch of environments in the vectorized environment.
Raises:
ValueError: Cannot batch space that is not a valid :class:`gym.Space` instance
Example:
>>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)
... })
>>> batch_space(space, n=5)
Dict('position': Box(0.0, 1.0, (5, 3), float32), 'velocity': Box(0.0, 1.0, (5, 2), float32))
"""
raise ValueError(
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gymnasium.Space` instance."
)
@batch_space.register(Box)
def _batch_space_box(space, n=1):
repeats = tuple([n] + [1] * space.low.ndim)
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
return Box(low=low, high=high, dtype=space.dtype, seed=deepcopy(space.np_random))
@batch_space.register(Discrete)
def _batch_space_discrete(space, n=1):
if space.start == 0:
return MultiDiscrete(
np.full((n,), space.n, dtype=space.dtype),
dtype=space.dtype,
seed=deepcopy(space.np_random),
)
else:
return Box(
low=space.start,
high=space.start + space.n - 1,
shape=(n,),
dtype=space.dtype,
seed=deepcopy(space.np_random),
)
@batch_space.register(MultiDiscrete)
def _batch_space_multidiscrete(space, n=1):
repeats = tuple([n] + [1] * space.nvec.ndim)
high = np.tile(space.nvec, repeats) - 1
return Box(
low=np.zeros_like(high),
high=high,
dtype=space.dtype,
seed=deepcopy(space.np_random),
)
@batch_space.register(MultiBinary)
def _batch_space_multibinary(space, n=1):
return Box(
low=0,
high=1,
shape=(n,) + space.shape,
dtype=space.dtype,
seed=deepcopy(space.np_random),
)
@batch_space.register(Tuple)
def _batch_space_tuple(space, n=1):
return Tuple(
tuple(batch_space(subspace, n=n) for subspace in space.spaces),
seed=deepcopy(space.np_random),
)
@batch_space.register(Dict)
def _batch_space_dict(space, n=1):
return Dict(
OrderedDict(
[
(key, batch_space(subspace, n=n))
for (key, subspace) in space.spaces.items()
]
),
seed=deepcopy(space.np_random),
)
@batch_space.register(Space)
def _batch_space_custom(space, n=1):
# Without deepcopy, then the space.np_random is batched_space.spaces[0].np_random
# Which is an issue if you are sampling actions of both the original space and the batched space
batched_space = Tuple(
tuple(deepcopy(space) for _ in range(n)), seed=deepcopy(space.np_random)
)
new_seeds = list(map(int, batched_space.np_random.integers(0, 1e8, n)))
batched_space.seed(new_seeds)
return batched_space
@singledispatch
def iterate(space: Space, items) -> Iterator:
"""Iterate over the elements of a (batched) space.
Args:
space: Space to which `items` belong.
items: Items to be iterated over.
Returns:
Iterator over the elements in `items`.
Raises:
ValueError: Space is not an instance of :class:`gym.Space`
Example:
>>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(2, 3), seed=42, dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2, 2), seed=42, dtype=np.float32)})
>>> items = space.sample()
>>> it = iterate(space, items)
>>> next(it)
OrderedDict([('position', array([0.77395606, 0.43887845, 0.85859793], dtype=float32)), ('velocity', array([0.77395606, 0.43887845], dtype=float32))])
>>> next(it)
OrderedDict([('position', array([0.697368 , 0.09417735, 0.97562236], dtype=float32)), ('velocity', array([0.85859793, 0.697368 ], dtype=float32))])
>>> next(it)
Traceback (most recent call last):
...
StopIteration
"""
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
)
@iterate.register(Discrete)
def _iterate_discrete(space, items):
raise TypeError("Unable to iterate over a space of type `Discrete`.")
@iterate.register(Box)
@iterate.register(MultiDiscrete)
@iterate.register(MultiBinary)
def _iterate_base(space, items):
try:
return iter(items)
except TypeError as e:
raise TypeError(
f"Unable to iterate over the following elements: {items}"
) from e
@iterate.register(Tuple)
def _iterate_tuple(space, items):
# If this is a tuple of custom subspaces only, then simply iterate over items
if all(
isinstance(subspace, Space)
and (not isinstance(subspace, (Box, Discrete, MultiDiscrete, Tuple, Dict)))
for subspace in space.spaces
):
return iter(items)
return zip(
*[iterate(subspace, items[i]) for i, subspace in enumerate(space.spaces)]
)
@iterate.register(Dict)
def _iterate_dict(space, items):
keys, values = zip(
*[
(key, iterate(subspace, items[key]))
for key, subspace in space.spaces.items()
]
)
for item in zip(*values):
yield OrderedDict([(key, value) for (key, value) in zip(keys, item)])
@iterate.register(Space)
def _iterate_custom(space, items):
raise CustomSpaceError(
f"Unable to iterate over {items}, since {space} "
"is a custom `gymnasium.Space` instance (i.e. not one of "
"`Box`, `Dict`, etc...)."
)
@singledispatch
def concatenate(
space: Space, items: Iterable, out: tuple | dict | np.ndarray
) -> tuple | dict | np.ndarray:
"""Concatenate multiple samples from space into a single object.
Args:
space: Observation space of a single environment in the vectorized environment.
items: Samples to be concatenated.
out: The output object. This object is a (possibly nested) numpy array.
Returns:
The output object. This object is a (possibly nested) numpy array.
Raises:
ValueError: Space is not a valid :class:`gym.Space` instance
Example:
>>> from gymnasium.spaces import Box
>>> import numpy as np
>>> space = Box(low=0, high=1, shape=(3,), seed=42, dtype=np.float32)
>>> out = np.zeros((2, 3), dtype=np.float32)
>>> items = [space.sample() for _ in range(2)]
>>> concatenate(space, items, out)
array([[0.77395606, 0.43887845, 0.85859793],
[0.697368 , 0.09417735, 0.97562236]], dtype=float32)
"""
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
)
@concatenate.register(Box)
@concatenate.register(Discrete)
@concatenate.register(MultiDiscrete)
@concatenate.register(MultiBinary)
def _concatenate_base(space, items, out):
return np.stack(items, axis=0, out=out)
@concatenate.register(Tuple)
def _concatenate_tuple(space, items, out):
return tuple(
concatenate(subspace, [item[i] for item in items], out[i])
for (i, subspace) in enumerate(space.spaces)
)
@concatenate.register(Dict)
def _concatenate_dict(space, items, out):
return OrderedDict(
[
(key, concatenate(subspace, [item[key] for item in items], out[key]))
for (key, subspace) in space.spaces.items()
]
)
@concatenate.register(Space)
def _concatenate_custom(space, items, out):
return tuple(items)
@singledispatch
def create_empty_array(
space: Space, n: int = 1, fn: Callable[..., np.ndarray] = np.zeros
) -> tuple | dict | np.ndarray:
"""Create an empty (possibly nested) numpy array.
Args:
space: Observation space of a single environment in the vectorized environment.
n: Number of environments in the vectorized environment. If `None`, creates an empty sample from `space`.
fn: Function to apply when creating the empty numpy array. Examples of such functions are `np.empty` or `np.zeros`.
Returns:
The output object. This object is a (possibly nested) numpy array.
Raises:
ValueError: Space is not a valid :class:`gym.Space` instance
Example:
>>> from gymnasium.spaces import Box, Dict
>>> import numpy as np
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
>>> create_empty_array(space, n=2, fn=np.zeros)
OrderedDict([('position', array([[0., 0., 0.],
[0., 0., 0.]], dtype=float32)), ('velocity', array([[0., 0.],
[0., 0.]], dtype=float32))])
"""
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gymnasium.Space` instance."
)
@create_empty_array.register(Box)
@create_empty_array.register(Discrete)
@create_empty_array.register(MultiDiscrete)
@create_empty_array.register(MultiBinary)
def _create_empty_array_base(space, n=1, fn=np.zeros):
shape = space.shape if (n is None) else (n,) + space.shape
return fn(shape, dtype=space.dtype)
@create_empty_array.register(Tuple)
def _create_empty_array_tuple(space, n=1, fn=np.zeros):
return tuple(create_empty_array(subspace, n=n, fn=fn) for subspace in space.spaces)
@create_empty_array.register(Dict)
def _create_empty_array_dict(space, n=1, fn=np.zeros):
return OrderedDict(
[
(key, create_empty_array(subspace, n=n, fn=fn))
for (key, subspace) in space.spaces.items()
]
)
@create_empty_array.register(Space)
def _create_empty_array_custom(space, n=1, fn=np.zeros):
return None
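A short sketch showing how the four helpers above cooperate, in the way a vector environment uses them (the space and samples are illustrative):

# Sketch: batch a single-env space, pre-allocate storage, stack samples,
# then iterate the batch back into per-environment items.
import numpy as np

from gymnasium.spaces import Box
from gymnasium.experimental.vector.utils import (
    batch_space,
    concatenate,
    create_empty_array,
    iterate,
)

single_space = Box(low=0.0, high=1.0, shape=(3,), dtype=np.float32, seed=42)
batched_space = batch_space(single_space, n=4)       # Box(0.0, 1.0, (4, 3), float32)

samples = [single_space.sample() for _ in range(4)]  # one sample per sub-environment
out = create_empty_array(single_space, n=4)          # pre-allocated (4, 3) array
batch = concatenate(single_space, samples, out)      # stacked into `out` in place

for single_obs in iterate(batched_space, batch):     # back to per-environment items
    assert single_obs in single_space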

View File

@@ -1,5 +1,7 @@
"""Base class for vectorized environments."""
from typing import TYPE_CHECKING, Generic, List, Optional, Tuple, TypeVar, Union
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Generic, TypeVar
import numpy as np
@@ -11,11 +13,19 @@ from gymnasium.utils import seeding
if TYPE_CHECKING:
from gymnasium.envs.registration import EnvSpec
__all__ = ["VectorEnv", "VectorWrapper"]
ArrayType = TypeVar("ArrayType")
__all__ = [
"VectorEnv",
"VectorWrapper",
"VectorObservationWrapper",
"VectorActionWrapper",
"VectorRewardWrapper",
"ArrayType",
]
class VectorEnv(Generic[ObsType, ActType, ArrayType]):
"""Base class for vectorized environments to run multiple independent copies of the same environment in parallel.
@@ -53,29 +63,23 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
In other words, a vector of multiple different environments is not supported.
"""
spec: "EnvSpec"
spec: EnvSpec
observation_space: gym.Space
action_space: gym.Space
num_envs: int
_np_random: Optional[np.random.Generator] = None
closed = False
def __init__(self, **kwargs):
"""Base class for vectorized environments.
Args:
num_envs: Number of environments in the vectorized environment.
"""
self.closed = False
_np_random: np.random.Generator | None = None
def reset(
self,
*,
seed: Optional[Union[int, List[int]]] = None,
options: Optional[dict] = None,
) -> Tuple[ObsType, dict]: # type: ignore
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
) -> tuple[ObsType, dict[str, Any]]: # type: ignore
"""Reset all parallel environments and return a batch of initial observations and info.
Args:
@@ -85,13 +89,21 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
Returns:
A batch of observations and info from the vectorized environment.
Example:
>>> import gymnasium as gym
>>> envs = gym.vector.make("CartPole-v1", num_envs=3)
>>> envs.reset(seed=42)
(array([[ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ],
[ 0.01522993, -0.04562247, -0.04799704, 0.03392126],
[-0.03774345, -0.02418869, -0.00942293, 0.0469184 ]],
dtype=float32), {})
"""
if seed is not None:
self._np_random, seed = seeding.np_random(seed)
def step(
self, actions: ActType
) -> Tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
) -> tuple[ObsType, ArrayType, ArrayType, ArrayType, dict]:
"""Take an action for each parallel environment.
Args:
@@ -104,6 +116,27 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
As the vector environments autoreset for terminating and truncating sub-environments,
the returned observation and info are not the final step's observation or info; these are instead stored in
info as `"final_observation"` and `"final_info"`.
Example:
>>> import gymnasium as gym
>>> import numpy as np
>>> envs = gym.vector.make("CartPole-v1", num_envs=3)
>>> _ = envs.reset(seed=42)
>>> actions = np.array([1, 0, 1])
>>> observations, rewards, termination, truncation, infos = envs.step(actions)
>>> observations
array([[ 0.02727336, 0.18847767, 0.03625453, -0.26141977],
[ 0.01431748, -0.24002443, -0.04731862, 0.3110827 ],
[-0.03822722, 0.1710671 , -0.00848456, -0.2487226 ]],
dtype=float32)
>>> rewards
array([1., 1., 1.])
>>> termination
array([False, False, False])
>>> truncation
array([False, False, False])
>>> infos
{}
"""
pass
@@ -181,7 +214,7 @@ class VectorEnv(Generic[ObsType, ActType, ArrayType]):
infos[k], infos[f"_{k}"] = info_array, array_mask
return infos
def _init_info_arrays(self, dtype: type) -> Tuple[np.ndarray, np.ndarray]:
def _init_info_arrays(self, dtype: type) -> tuple[np.ndarray, np.ndarray]:
"""Initialize the info array.
Initialize the info array. If the dtype is numeric
@@ -285,10 +318,12 @@ class VectorObservationWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the observation. Equivalent to :class:`gym.ObservationWrapper` for vectorized environments."""
def reset(self, **kwargs):
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
observation = self.env.reset(**kwargs)
return self.observation(observation)
def step(self, actions):
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
observation, reward, termination, truncation, info = self.env.step(actions)
return (
self.observation(observation),
@@ -315,6 +350,7 @@ class VectorActionWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the actions. Equivalent of :class:`~gym.ActionWrapper` for vectorized environments."""
def step(self, actions: ActType):
"""Steps through the environment using a modified action by :meth:`action`."""
return self.env.step(self.action(actions))
def actions(self, actions: ActType) -> ActType:
@@ -333,6 +369,7 @@ class VectorRewardWrapper(VectorWrapper):
"""Wraps the vectorized environment to allow a modular transformation of the reward. Equivalent of :class:`~gym.RewardWrapper` for vectorized environments."""
def step(self, actions):
"""Steps through the environment returning a reward modified by :meth:`reward`."""
observation, reward, termination, truncation, info = self.env.step(actions)
return observation, self.reward(reward), termination, truncation, info
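To illustrate the new wrapper base classes, a minimal sketch of a reward-scaling `VectorRewardWrapper` (the scaling factor and environment are illustrative, and it assumes `VectorWrapper` forwards `reset` and `close` to the wrapped env as usual):

# Sketch: subclass VectorRewardWrapper and override reward().
import numpy as np

import gymnasium as gym
from gymnasium.experimental.vector import SyncVectorEnv, VectorRewardWrapper


class ScaleReward(VectorRewardWrapper):
    """Multiplies every sub-environment reward by a constant factor."""

    def __init__(self, env, scale: float = 0.1):
        super().__init__(env)
        self.scale = scale

    def reward(self, reward):
        return self.scale * np.asarray(reward)


envs = ScaleReward(SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(3)]))
obs, info = envs.reset(seed=42)
obs, rewards, terminations, truncations, infos = envs.step(np.array([1, 0, 1]))
assert np.allclose(rewards, 0.1)  # CartPole returns +1 per step before scaling
envs.close()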

View File

@@ -1,3 +1,5 @@
"""Testing of the `gym.make_vec` function."""
import pytest
import gymnasium as gym
@@ -6,7 +8,8 @@ from gymnasium.wrappers import TimeLimit, TransformObservation
from tests.wrappers.utils import has_wrapper
def test_vector_make_id():
def test_make_vec_env_id():
"""Ensure that the `gym.make_vec` creates the right environment."""
env = gym.make_vec("CartPole-v1")
assert isinstance(env, AsyncVectorEnv)
assert env.num_envs == 1
@@ -14,13 +17,15 @@ def test_vector_make_id():
@pytest.mark.parametrize("num_envs", [1, 3, 10])
def test_vector_make_num_envs(num_envs):
def test_make_vec_num_envs(num_envs):
"""Test that the `gym.make_vec` num_envs parameter works."""
env = gym.make_vec("CartPole-v1", num_envs=num_envs)
assert env.num_envs == num_envs
env.close()
def test_vector_make_asynchronous():
def test_make_vec_vectorization_mode():
"""Tests the `gym.make_vec` vectorization mode works."""
env = gym.make_vec("CartPole-v1", vectorization_mode="async")
assert isinstance(env, AsyncVectorEnv)
env.close()
@@ -30,7 +35,8 @@ def test_vector_make_asynchronous():
env.close()
def test_vector_make_wrappers():
def test_make_vec_wrappers():
"""Tests that the `gym.make_vec` wrappers parameter works."""
env = gym.make_vec("CartPole-v1", num_envs=2, vectorization_mode="sync")
assert isinstance(env, SyncVectorEnv)
assert len(env.envs) == 2
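A condensed sketch of the `gym.make_vec` calls these tests exercise (environment id and values are illustrative):

# Sketch: the make_vec arguments covered by the tests above.
import gymnasium as gym

envs = gym.make_vec("CartPole-v1", num_envs=3, vectorization_mode="sync")
assert envs.num_envs == 3
envs.close()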

View File

@@ -0,0 +1 @@
"""Testing for `gymnasium.experimental.vector`."""

View File

@@ -1,3 +1,5 @@
"""Test the `SyncVectorEnv` implementation."""
import re
from multiprocessing import TimeoutError
@@ -11,7 +13,7 @@ from gymnasium.error import (
)
from gymnasium.experimental.vector import AsyncVectorEnv
from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple
from tests.vector.utils import (
from tests.experimental.vector.testing_utils import (
CustomSpace,
make_custom_space_env,
make_env,
@@ -21,6 +23,7 @@ from tests.vector.utils import (
@pytest.mark.parametrize("shared_memory", [True, False])
def test_create_async_vector_env(shared_memory):
"""Test creating an async vector environment with or without shared memory."""
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
@@ -30,6 +33,7 @@ def test_create_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_async_vector_env(shared_memory):
"""Test the reset of an sync vector environment with or without shared memory."""
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
@@ -61,6 +65,7 @@ def test_reset_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
@pytest.mark.parametrize("use_single_action_space", [True, False])
def test_step_async_vector_env(shared_memory, use_single_action_space):
"""Test the step async vector environment with and without shared memory."""
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
@@ -73,7 +78,7 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
actions = [env.single_action_space.sample() for _ in range(8)]
else:
actions = env.action_space.sample()
observations, rewards, terminateds, truncateds, _ = env.step(actions)
observations, rewards, terminations, truncations, _ = env.step(actions)
env.close()
@@ -88,19 +93,20 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
assert rewards.ndim == 1
assert rewards.size == 8
assert isinstance(terminateds, np.ndarray)
assert terminateds.dtype == np.bool_
assert terminateds.ndim == 1
assert terminateds.size == 8
assert isinstance(terminations, np.ndarray)
assert terminations.dtype == np.bool_
assert terminations.ndim == 1
assert terminations.size == 8
assert isinstance(truncateds, np.ndarray)
assert truncateds.dtype == np.bool_
assert truncateds.ndim == 1
assert truncateds.size == 8
assert isinstance(truncations, np.ndarray)
assert truncations.dtype == np.bool_
assert truncations.ndim == 1
assert truncations.size == 8
@pytest.mark.parametrize("shared_memory", [True, False])
def test_call_async_vector_env(shared_memory):
"""Test call with async vector environment."""
env_fns = [
make_env("CartPole-v1", i, render_mode="rgb_array_list") for i in range(4)
]
@@ -127,6 +133,7 @@ def test_call_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
def test_set_attr_async_vector_env(shared_memory):
"""Test `set_attr_` for async vector environment with or without shared memory."""
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
@@ -139,6 +146,7 @@ def test_set_attr_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
def test_copy_async_vector_env(shared_memory):
"""Test observations are a copy of the true observation with and without shared memory."""
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
# TODO, these tests do nothing, understand the purpose of the tests and fix them
@@ -151,6 +159,7 @@ def test_copy_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
def test_no_copy_async_vector_env(shared_memory):
"""Test observation are not a copy of the true observation with and without shared memory."""
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
# TODO, these tests do nothing, understand the purpose of the tests and fix them
@@ -163,6 +172,7 @@ def test_no_copy_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_timeout_async_vector_env(shared_memory):
"""Test timeout error on reset with and without shared memory."""
env_fns = [make_slow_env(0.3, i) for i in range(4)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
@@ -175,18 +185,20 @@ def test_reset_timeout_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
def test_step_timeout_async_vector_env(shared_memory):
"""Test timeout error on step with and without shared memory."""
env_fns = [make_slow_env(0.0, i) for i in range(4)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
with pytest.raises(TimeoutError):
env.reset()
env.step_async(np.array([0.1, 0.1, 0.3, 0.1]))
observations, rewards, terminateds, truncateds, _ = env.step_wait(timeout=0.1)
observations, rewards, terminations, truncations, _ = env.step_wait(timeout=0.1)
env.close(terminate=True)
@pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_out_of_order_async_vector_env(shared_memory):
"""Test reset being called out of order with and without shared memory."""
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
@@ -223,6 +235,7 @@ def test_reset_out_of_order_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
def test_step_out_of_order_async_vector_env(shared_memory):
"""Test step out of order with and without shared memory."""
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
@@ -258,6 +271,7 @@ def test_step_out_of_order_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
def test_already_closed_async_vector_env(shared_memory):
"""Test the error if a function is called if environment is already closed."""
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
with pytest.raises(ClosedEnvironmentError):
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
@@ -267,6 +281,7 @@ def test_already_closed_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
def test_check_spaces_async_vector_env(shared_memory):
"""Test check spaces for async vector environment with and without shared memory."""
# CartPole-v1 - observation_space: Box(4,), action_space: Discrete(2)
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
# FrozenLake-v1 - Discrete(16), action_space: Discrete(4)
@@ -277,6 +292,7 @@ def test_check_spaces_async_vector_env(shared_memory):
def test_custom_space_async_vector_env():
"""Test custom spaces with async vector environment."""
env_fns = [make_custom_space_env(i) for i in range(4)]
env = AsyncVectorEnv(env_fns, shared_memory=False)
@@ -286,7 +302,7 @@ def test_custom_space_async_vector_env():
assert isinstance(env.action_space, Tuple)
actions = ("action-2", "action-3", "action-5", "action-7")
step_observations, rewards, terminateds, truncateds, _ = env.step(actions)
step_observations, rewards, terminations, truncations, _ = env.step(actions)
env.close()
@@ -306,6 +322,7 @@ def test_custom_space_async_vector_env():
def test_custom_space_async_vector_env_shared_memory():
"""Test custom space with shared memory."""
env_fns = [make_custom_space_env(i) for i in range(4)]
with pytest.raises(ValueError):
env = AsyncVectorEnv(env_fns, shared_memory=True)

View File

@@ -1,3 +1,5 @@
"""Test the `SyncVectorEnv` implementation."""
import numpy as np
import pytest
@@ -14,6 +16,7 @@ from tests.vector.utils import (
def test_create_sync_vector_env():
"""Tests creating the sync vector environment."""
env_fns = [make_env("FrozenLake-v1", i) for i in range(8)]
env = SyncVectorEnv(env_fns)
env.close()
@@ -22,6 +25,7 @@ def test_create_sync_vector_env():
def test_reset_sync_vector_env():
"""Tests sync vector `reset` function."""
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
env = SyncVectorEnv(env_fns)
observations, infos = env.reset()
@@ -38,6 +42,7 @@ def test_reset_sync_vector_env():
@pytest.mark.parametrize("use_single_action_space", [True, False])
def test_step_sync_vector_env(use_single_action_space):
"""Test sync vector `steps` function."""
env_fns = [make_env("FrozenLake-v1", i) for i in range(8)]
env = SyncVectorEnv(env_fns)
@@ -77,6 +82,7 @@ def test_step_sync_vector_env(use_single_action_space):
def test_call_sync_vector_env():
"""Test sync vector `call` on sub-environments."""
env_fns = [
make_env("CartPole-v1", i, render_mode="rgb_array_list") for i in range(4)
]
@@ -102,6 +108,7 @@ def test_call_sync_vector_env():
def test_set_attr_sync_vector_env():
"""Test sync vector `set_attr` function."""
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
env = SyncVectorEnv(env_fns)
@@ -113,6 +120,7 @@ def test_set_attr_sync_vector_env():
def test_check_spaces_sync_vector_env():
"""Tests the sync vector `check_spaces` function."""
# CartPole-v1 - observation_space: Box(4,), action_space: Discrete(2)
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
# FrozenLake-v1 - Discrete(16), action_space: Discrete(4)
@@ -123,6 +131,7 @@ def test_check_spaces_sync_vector_env():
def test_custom_space_sync_vector_env():
"""Test the use of custom spaces with sync vector environment."""
env_fns = [make_custom_space_env(i) for i in range(4)]
env = SyncVectorEnv(env_fns)
@@ -152,6 +161,7 @@ def test_custom_space_sync_vector_env():
def test_sync_vector_env_seed():
"""Test seeding for sync vector environments."""
env = make_env("BipedalWalker-v3", seed=123)()
sync_vector_env = SyncVectorEnv([make_env("BipedalWalker-v3", seed=123)])
@@ -166,7 +176,7 @@ def test_sync_vector_env_seed():
"spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
)
def test_sync_vector_determinism(spec: EnvSpec, seed: int = 123, n: int = 3):
"""Check that for all environments, the sync vector envs produce the same action samples using the same seeds"""
"""Check that for all environments, the sync vector envs produce the same action samples using the same seeds."""
env_1 = SyncVectorEnv([make_env(spec.id, seed=seed) for _ in range(n)])
env_2 = SyncVectorEnv([make_env(spec.id, seed=seed) for _ in range(n)])
assert_rng_equal(env_1.action_space.np_random, env_2.action_space.np_random)

View File

@@ -1,3 +1,5 @@
"""Test vector environment implementations."""
from functools import partial
import numpy as np
@@ -11,6 +13,7 @@ from tests.vector.utils import make_env
@pytest.mark.parametrize("shared_memory", [True, False])
def test_vector_env_equal(shared_memory):
"""Test that vector environment are equal for both async and sync variants."""
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
num_steps = 100
@@ -31,12 +34,22 @@ def test_vector_env_equal(shared_memory):
actions = async_env.action_space.sample()
assert actions in sync_env.action_space
# fmt: off
async_observations, async_rewards, async_terminateds, async_truncateds, async_infos = async_env.step(actions)
sync_observations, sync_rewards, sync_terminateds, sync_truncateds, sync_infos = sync_env.step(actions)
# fmt: on
(
async_observations,
async_rewards,
async_terminations,
async_truncations,
async_infos,
) = async_env.step(actions)
(
sync_observations,
sync_rewards,
sync_terminations,
sync_truncations,
sync_infos,
) = sync_env.step(actions)
if any(sync_terminateds) or any(sync_truncateds):
if any(sync_terminations) or any(sync_truncations):
assert "final_observation" in async_infos
assert "_final_observation" in async_infos
assert "final_observation" in sync_infos
@@ -44,8 +57,8 @@ def test_vector_env_equal(shared_memory):
assert np.all(async_observations == sync_observations)
assert np.all(async_rewards == sync_rewards)
assert np.all(async_terminateds == sync_terminateds)
assert np.all(async_truncateds == sync_truncateds)
assert np.all(async_terminations == sync_terminations)
assert np.all(async_truncations == sync_truncations)
async_env.close()
sync_env.close()

View File

@@ -1,3 +1,4 @@
"""Test the vector environment information."""
import numpy as np
import pytest
@@ -12,12 +13,13 @@ ENV_STEPS = 50
SEED = 42
@pytest.mark.parametrize("asynchronous", [True, False])
def test_vector_env_info(asynchronous: bool):
@pytest.mark.parametrize("vectorization_mode", ["async", "sync"])
def test_vector_env_info(vectorization_mode: str):
"""Test vector environment info for different vectorization modes."""
env = gym.make_vec(
ENV_ID,
num_envs=NUM_ENVS,
vectorization_mode="async" if asynchronous else "sync",
vectorization_mode=vectorization_mode,
)
env.reset(seed=SEED)
for _ in range(ENV_STEPS):
@@ -41,6 +43,7 @@ def test_vector_env_info(asynchronous: bool):
@pytest.mark.parametrize("concurrent_ends", [1, 2, 3])
def test_vector_env_info_concurrent_termination(concurrent_ends):
"""Test the vector environment information works with concurrent termination."""
# envs that need to terminate together will have the same action
actions = [0] * concurrent_ends + [1] * (NUM_ENVS - concurrent_ends)
envs = [make_env(ENV_ID, SEED) for _ in range(NUM_ENVS)]

View File

@@ -1,23 +1,29 @@
"""Tests the vector wrappers work as expected."""
import numpy as np
import gymnasium as gym
from gymnasium.experimental.vector import VectorWrapper
class DummyWrapper(VectorWrapper):
class DummyVectorWrapper(VectorWrapper):
"""Dummy Vector wrapper that contains a counter function to logging the number of times that reset is called."""
def __init__(self, env):
"""Initialises the wrapper with the environment creating a counter variable."""
super().__init__(env)
self.env = env
self.counter = 0
def reset(self, **kwargs):
"""Updates the ``counter`` each time at ``reset`` is called."""
super().reset()
self.counter += 1
def test_vector_env_wrapper_inheritance():
"""Test vector environment wrapper inheritance."""
env = gym.make_vec("FrozenLake-v1", vectorization_mode="async")
wrapped = DummyWrapper(env)
wrapped = DummyVectorWrapper(env)
wrapped.reset()
assert wrapped.counter == 1
@@ -25,7 +31,7 @@ def test_vector_env_wrapper_inheritance():
def test_vector_env_wrapper_attributes():
"""Test if `set_attr`, `call` methods for VecEnvWrapper get correctly forwarded to the vector env it is wrapping."""
env = gym.make_vec("CartPole-v1", num_envs=3)
wrapped = DummyWrapper(gym.make_vec("CartPole-v1", num_envs=3))
wrapped = DummyVectorWrapper(gym.make_vec("CartPole-v1", num_envs=3))
assert np.allclose(wrapped.call("gravity"), env.call("gravity"))
env.set_attr("gravity", [20.0, 20.0, 20.0])

View File

@@ -0,0 +1,163 @@
"""Testing utilitys for `gymnasium.experimental.vector`."""
import time
from typing import Optional
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple
BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
spaces = [
Box(low=np.array(-1.0), high=np.array(1.0), dtype=np.float64),
Box(low=np.array([0.0]), high=np.array([10.0]), dtype=np.float64),
Box(
low=np.array([-1.0, 0.0, 0.0]), high=np.array([1.0, 1.0, 1.0]), dtype=np.float64
),
Box(
low=np.array([[-1.0, 0.0], [0.0, -1.0]]), high=np.ones((2, 2)), dtype=np.float64
),
Box(low=0, high=255, shape=(), dtype=np.uint8),
Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
Discrete(2),
Discrete(5, start=-2),
Tuple((Discrete(3), Discrete(5))),
Tuple(
(
Discrete(7),
Box(low=np.array([0.0, -1.0]), high=np.array([1.0, 1.0]), dtype=np.float64),
)
),
MultiDiscrete([11, 13, 17]),
MultiBinary(19),
Dict(
{
"position": Discrete(23),
"velocity": Box(
low=np.array([0.0]), high=np.array([1.0]), dtype=np.float64
),
}
),
Dict(
{
"position": Dict({"x": Discrete(29), "y": Discrete(31)}),
"velocity": Tuple(
(Discrete(37), Box(low=0, high=255, shape=(), dtype=np.uint8))
),
}
),
]
HEIGHT, WIDTH = 64, 64
class SlowEnv(gym.Env):
"""A custom slow environment."""
def __init__(self, slow_reset=0.3):
"""Initialises the environment with a slow reset parameter used in the `step` and `reset` functions."""
super().__init__()
self.slow_reset = slow_reset
self.observation_space = Box(
low=0, high=255, shape=(HEIGHT, WIDTH, 3), dtype=np.uint8
)
self.action_space = Box(low=0.0, high=1.0, shape=(), dtype=np.float32)
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
"""Resets the environment with a time sleep."""
super().reset(seed=seed)
if self.slow_reset > 0:
time.sleep(self.slow_reset)
return self.observation_space.sample(), {}
def step(self, action):
"""Steps through the environment with a time sleep."""
time.sleep(action)
observation = self.observation_space.sample()
reward, terminated, truncated = 0.0, False, False
return observation, reward, terminated, truncated, {}
class CustomSpace(gym.Space):
"""Minimal custom observation space."""
def sample(self):
"""Generates a sample from the custom space."""
return self.np_random.integers(0, 10, ())
def contains(self, x):
"""Check if the element `x` is contained within the space."""
return 0 <= x <= 10
def __eq__(self, other):
"""Check if the two spaces are equal."""
return isinstance(other, CustomSpace)
custom_spaces = [
CustomSpace(),
Tuple((CustomSpace(), Box(low=0, high=255, shape=(), dtype=np.uint8))),
]
class CustomSpaceEnv(gym.Env):
"""An environment with custom spaces for observation and action spaces."""
def __init__(self):
"""Initialise the environment."""
super().__init__()
self.observation_space = CustomSpace()
self.action_space = CustomSpace()
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
"""Resets the environment."""
super().reset(seed=seed)
return "reset", {}
def step(self, action):
"""Steps through the environment."""
observation = f"step({action:s})"
reward, terminated, truncated = 0.0, False, False
return observation, reward, terminated, truncated, {}
def make_env(env_name, seed, **kwargs):
"""Creates an environment."""
def _make():
env = gym.make(env_name, disable_env_checker=True, **kwargs)
env.action_space.seed(seed)
env.reset(seed=seed)
return env
return _make
def make_slow_env(slow_reset, seed):
"""Creates an environment with slow reset."""
def _make():
env = SlowEnv(slow_reset=slow_reset)
env.reset(seed=seed)
return env
return _make
def make_custom_space_env(seed):
"""Creates a custom space environment."""
def _make():
env = CustomSpaceEnv()
env.reset(seed=seed)
return env
return _make
def assert_rng_equal(rng_1: np.random.Generator, rng_2: np.random.Generator):
"""Tests whether two random number generators are equal."""
assert rng_1.bit_generator.state == rng_2.bit_generator.state

View File

@@ -0,0 +1 @@
"""Module for testing `gymnasium.experimental.vector.utils` functions."""

View File

@@ -0,0 +1,187 @@
"""Tests `gymnasium.experimental.vector.utils.shared_memory functions."""
import multiprocessing as mp
from collections import OrderedDict
from multiprocessing import Array, Process
from multiprocessing.sharedctypes import SynchronizedArray
import numpy as np
import pytest
from gymnasium.error import CustomSpaceError
from gymnasium.experimental.vector.utils import (
create_shared_memory,
read_from_shared_memory,
write_to_shared_memory,
)
from gymnasium.spaces import Dict, Tuple
from gymnasium.vector.utils import BaseGymSpaces
from tests.experimental.vector.testing_utils import custom_spaces, spaces
expected_types = [
Array("d", 1),
Array("f", 1),
Array("f", 3),
Array("f", 4),
Array("B", 1),
Array("B", 32 * 32 * 3),
Array("i", 1),
Array("i", 1),
(Array("i", 1), Array("i", 1)),
(Array("i", 1), Array("f", 2)),
Array("B", 3),
Array("B", 19),
OrderedDict([("position", Array("i", 1)), ("velocity", Array("f", 1))]),
OrderedDict(
[
("position", OrderedDict([("x", Array("i", 1)), ("y", Array("i", 1))])),
("velocity", (Array("i", 1), Array("B", 1))),
]
),
]
@pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize(
"space,expected_type",
list(zip(spaces, expected_types)),
ids=[space.__class__.__name__ for space in spaces],
)
@pytest.mark.parametrize(
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
)
def test_create_shared_memory(space, expected_type, n, ctx):
"""Tests the `create_shared_memory` function with a number of spaces."""
def assert_nested_type(lhs, rhs, n):
assert type(lhs) == type(rhs)
if isinstance(lhs, (list, tuple)):
assert len(lhs) == len(rhs)
for lhs_, rhs_ in zip(lhs, rhs):
assert_nested_type(lhs_, rhs_, n)
elif isinstance(lhs, (dict, OrderedDict)):
assert set(lhs.keys()) ^ set(rhs.keys()) == set()
for key in lhs.keys():
assert_nested_type(lhs[key], rhs[key], n)
elif isinstance(lhs, SynchronizedArray):
# Assert the length of the array
assert len(lhs[:]) == n * len(rhs[:])
# Assert the data type
assert isinstance(lhs[0], type(rhs[0]))
else:
raise TypeError(f"Got unknown type `{type(lhs)}`.")
ctx = mp if (ctx is None) else mp.get_context(ctx)
shared_memory = create_shared_memory(space, n=n, ctx=ctx)
assert_nested_type(shared_memory, expected_type, n=n)
@pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize(
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
)
@pytest.mark.parametrize("space", custom_spaces)
def test_create_shared_memory_custom_space(n, ctx, space):
"""Tests the `create_shared_memory` function with a custom space."""
ctx = mp if (ctx is None) else mp.get_context(ctx)
with pytest.raises(CustomSpaceError):
create_shared_memory(space, n=n, ctx=ctx)
def _write_shared_memory(space, i, shared_memory, sample):
write_to_shared_memory(space, i, sample, shared_memory)
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_write_to_shared_memory(space):
"""Tests `write_to_shared_memory` function with a list of spaces."""
def assert_nested_equal(lhs, rhs):
assert isinstance(rhs, list)
if isinstance(lhs, (list, tuple)):
for i in range(len(lhs)):
assert_nested_equal(lhs[i], [rhs_[i] for rhs_ in rhs])
elif isinstance(lhs, (dict, OrderedDict)):
for key in lhs.keys():
assert_nested_equal(lhs[key], [rhs_[key] for rhs_ in rhs])
elif isinstance(lhs, SynchronizedArray):
assert np.all(np.array(lhs[:]) == np.stack(rhs, axis=0).flatten())
else:
raise TypeError(f"Got unknown type `{type(lhs)}`.")
shared_memory_n8 = create_shared_memory(space, n=8)
samples = [space.sample() for _ in range(8)]
processes = [
Process(
target=_write_shared_memory, args=(space, i, shared_memory_n8, samples[i])
)
for i in range(8)
]
for process in processes:
process.start()
for process in processes:
process.join()
assert_nested_equal(shared_memory_n8, samples)
def _process_write(space, i, shared_memory, sample):
write_to_shared_memory(space, i, sample, shared_memory)
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_read_from_shared_memory(space):
"""Tests `read_from_shared_memory` function with list of spaces."""
def assert_nested_equal(lhs, rhs, space, n):
assert isinstance(rhs, list)
if isinstance(space, Tuple):
assert isinstance(lhs, tuple)
for i in range(len(lhs)):
assert_nested_equal(
lhs[i], [rhs_[i] for rhs_ in rhs], space.spaces[i], n
)
elif isinstance(space, Dict):
assert isinstance(lhs, OrderedDict)
for key in lhs.keys():
assert_nested_equal(
lhs[key], [rhs_[key] for rhs_ in rhs], space.spaces[key], n
)
elif isinstance(space, BaseGymSpaces):
assert isinstance(lhs, np.ndarray)
assert lhs.shape == ((n,) + space.shape)
assert lhs.dtype == space.dtype
assert np.all(lhs == np.stack(rhs, axis=0))
else:
raise TypeError(f"Got unknown type `{type(space)}`")
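
    # The test relies on read_from_shared_memory returning a view over the
    # buffer, so writes made by the worker processes afterwards are visible
    # through memory_view_n8.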
shared_memory_n8 = create_shared_memory(space, n=8)
memory_view_n8 = read_from_shared_memory(space, shared_memory_n8, n=8)
samples = [space.sample() for _ in range(8)]
processes = [
Process(target=_process_write, args=(space, i, shared_memory_n8, samples[i]))
for i in range(8)
]
for process in processes:
process.start()
for process in processes:
process.join()
assert_nested_equal(memory_view_n8, samples, space, n=8)

View File

@@ -0,0 +1,356 @@
"""Testing `gymnasium.experimental.vector.utils.space_utils` functions."""
import copy
from collections import OrderedDict
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from gymnasium.experimental.vector.utils import (
batch_space,
concatenate,
create_empty_array,
iterate,
)
from gymnasium.spaces import Box, Dict, MultiDiscrete, Space, Tuple
from tests.experimental.vector.testing_utils import (
BaseGymSpaces,
CustomSpace,
assert_rng_equal,
custom_spaces,
spaces,
)
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_concatenate(space):
"""Tests the `concatenate` functions with list of spaces."""
def assert_type(lhs, rhs, n):
# Special case: if rhs is a list of scalars, lhs must be an np.ndarray
if np.isscalar(rhs[0]):
assert isinstance(lhs, np.ndarray)
assert all([np.isscalar(rhs[i]) for i in range(n)])
else:
assert all([isinstance(rhs[i], type(lhs)) for i in range(n)])
def assert_nested_equal(lhs, rhs, n):
assert isinstance(rhs, list)
assert (n > 0) and (len(rhs) == n)
assert_type(lhs, rhs, n)
if isinstance(lhs, np.ndarray):
assert lhs.shape[0] == n
for i in range(n):
assert np.all(lhs[i] == rhs[i])
elif isinstance(lhs, tuple):
for i in range(len(lhs)):
rhs_T_i = [rhs[j][i] for j in range(n)]
assert_nested_equal(lhs[i], rhs_T_i, n)
elif isinstance(lhs, OrderedDict):
for key in lhs.keys():
rhs_T_key = [rhs[j][key] for j in range(n)]
assert_nested_equal(lhs[key], rhs_T_key, n)
else:
raise TypeError(f"Got unknown type `{type(lhs)}`.")
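
    # The test checks that the value returned by concatenate matches the
    # pre-allocated `array` it fills, and that both agree with the samples.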
samples = [space.sample() for _ in range(8)]
array = create_empty_array(space, n=8)
concatenated = concatenate(space, samples, array)
assert np.all(concatenated == array)
assert_nested_equal(array, samples, n=8)
@pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_create_empty_array(space, n):
"""Test `create_empty_array` function with list of spaces and different `n` values."""
def assert_nested_type(arr, space, n):
if isinstance(space, BaseGymSpaces):
assert isinstance(arr, np.ndarray)
assert arr.dtype == space.dtype
assert arr.shape == (n,) + space.shape
elif isinstance(space, Tuple):
assert isinstance(arr, tuple)
assert len(arr) == len(space.spaces)
for i in range(len(arr)):
assert_nested_type(arr[i], space.spaces[i], n)
elif isinstance(space, Dict):
assert isinstance(arr, OrderedDict)
assert set(arr.keys()) ^ set(space.spaces.keys()) == set()
for key in arr.keys():
assert_nested_type(arr[key], space.spaces[key], n)
else:
raise TypeError(f"Got unknown type `{type(arr)}`.")
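
    # With fn=np.empty the contents are uninitialised; only the nested
    # structure, dtype and batched shape (n,) + space.shape are checked.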
array = create_empty_array(space, n=n, fn=np.empty)
assert_nested_type(array, space, n=n)
@pytest.mark.parametrize("n", [1, 8])
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_create_empty_array_zeros(space, n):
"""Test `create_empty_array` with a list of spaces and different `n`."""
def assert_nested_type(arr, space, n):
if isinstance(space, BaseGymSpaces):
assert isinstance(arr, np.ndarray)
assert arr.dtype == space.dtype
assert arr.shape == (n,) + space.shape
assert np.all(arr == 0)
elif isinstance(space, Tuple):
assert isinstance(arr, tuple)
assert len(arr) == len(space.spaces)
for i in range(len(arr)):
assert_nested_type(arr[i], space.spaces[i], n)
elif isinstance(space, Dict):
assert isinstance(arr, OrderedDict)
assert set(arr.keys()) ^ set(space.spaces.keys()) == set()
for key in arr.keys():
assert_nested_type(arr[key], space.spaces[key], n)
else:
raise TypeError(f"Got unknown type `{type(arr)}`.")
array = create_empty_array(space, n=n, fn=np.zeros)
assert_nested_type(array, space, n=n)
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
def test_create_empty_array_none_shape_ones(space):
"""Tests `create_empty_array` with ``None`` space."""
def assert_nested_type(arr, space):
if isinstance(space, BaseGymSpaces):
assert isinstance(arr, np.ndarray)
assert arr.dtype == space.dtype
assert arr.shape == space.shape
assert np.all(arr == 1)
elif isinstance(space, Tuple):
assert isinstance(arr, tuple)
assert len(arr) == len(space.spaces)
for i in range(len(arr)):
assert_nested_type(arr[i], space.spaces[i])
elif isinstance(space, Dict):
assert isinstance(arr, OrderedDict)
assert set(arr.keys()) ^ set(space.spaces.keys()) == set()
for key in arr.keys():
assert_nested_type(arr[key], space.spaces[key])
else:
raise TypeError(f"Got unknown type `{type(arr)}`.")
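
    # With n=None the array is not batched and keeps the space's own shape,
    # and fn=np.ones fills it with ones.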
array = create_empty_array(space, n=None, fn=np.ones)
assert_nested_type(array, space)
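

# Expected result of batch_space(space, n=4) for each space in the `spaces` list.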
expected_batch_spaces_4 = [
Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64),
Box(low=0.0, high=10.0, shape=(4, 1), dtype=np.float64),
Box(
low=np.array(
[[-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]
),
high=np.array(
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
),
dtype=np.float64,
),
Box(
low=np.array(
[
[[-1.0, 0.0], [0.0, -1.0]],
[[-1.0, 0.0], [0.0, -1.0]],
[[-1.0, 0.0], [0.0, -1]],
[[-1.0, 0.0], [0.0, -1.0]],
]
),
high=np.ones((4, 2, 2)),
dtype=np.float64,
),
Box(low=0, high=255, shape=(4,), dtype=np.uint8),
Box(low=0, high=255, shape=(4, 32, 32, 3), dtype=np.uint8),
MultiDiscrete([2, 2, 2, 2]),
Box(low=-2, high=2, shape=(4,), dtype=np.int64),
Tuple((MultiDiscrete([3, 3, 3, 3]), MultiDiscrete([5, 5, 5, 5]))),
Tuple(
(
MultiDiscrete([7, 7, 7, 7]),
Box(
low=np.array([[0.0, -1.0], [0.0, -1.0], [0.0, -1.0], [0.0, -1]]),
high=np.array([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]),
dtype=np.float64,
),
)
),
Box(
low=np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]),
high=np.array([[10, 12, 16], [10, 12, 16], [10, 12, 16], [10, 12, 16]]),
dtype=np.int64,
),
Box(low=0, high=1, shape=(4, 19), dtype=np.int8),
Dict(
{
"position": MultiDiscrete([23, 23, 23, 23]),
"velocity": Box(low=0.0, high=1.0, shape=(4, 1), dtype=np.float64),
}
),
Dict(
{
"position": Dict(
{
"x": MultiDiscrete([29, 29, 29, 29]),
"y": MultiDiscrete([31, 31, 31, 31]),
}
),
"velocity": Tuple(
(
MultiDiscrete([37, 37, 37, 37]),
Box(low=0, high=255, shape=(4,), dtype=np.uint8),
)
),
}
),
]
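
# Expected result of batch_space(space, n=4) for each space in `custom_spaces`.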
expected_custom_batch_spaces_4 = [
Tuple((CustomSpace(), CustomSpace(), CustomSpace(), CustomSpace())),
Tuple(
(
Tuple((CustomSpace(), CustomSpace(), CustomSpace(), CustomSpace())),
Box(low=0, high=255, shape=(4,), dtype=np.uint8),
)
),
]
@pytest.mark.parametrize(
"space,expected_batch_space_4",
list(zip(spaces, expected_batch_spaces_4)),
ids=[space.__class__.__name__ for space in spaces],
)
def test_batch_space(space, expected_batch_space_4):
"""Tests `batch_space` with the expected spaces."""
batch_space_4 = batch_space(space, n=4)
assert batch_space_4 == expected_batch_space_4
@pytest.mark.parametrize(
"space,expected_batch_space_4",
list(zip(custom_spaces, expected_custom_batch_spaces_4)),
ids=[space.__class__.__name__ for space in custom_spaces],
)
def test_batch_space_custom_space(space, expected_batch_space_4):
"""Tests `batch_space` for custom spaces with the expected batch spaces."""
batch_space_4 = batch_space(space, n=4)
assert batch_space_4 == expected_batch_space_4
@pytest.mark.parametrize(
"space,batched_space",
list(zip(spaces, expected_batch_spaces_4)),
ids=[space.__class__.__name__ for space in spaces],
)
def test_iterate(space, batched_space):
"""Test `iterate` function with list of spaces and expected batch space."""
items = batched_space.sample()
iterator = iterate(batched_space, items)
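    # The batched space was created with n=4, so iterate should yield 4 items.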
i = 0
for i, item in enumerate(iterator):
assert item in space
assert i == 3
@pytest.mark.parametrize(
"space,batched_space",
list(zip(custom_spaces, expected_custom_batch_spaces_4)),
ids=[space.__class__.__name__ for space in custom_spaces],
)
def test_iterate_custom_space(space, batched_space):
"""Test iterating over a custom space."""
items = batched_space.sample()
iterator = iterate(batched_space, items)
i = 0
for i, item in enumerate(iterator):
assert item in space
assert i == 3
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
@pytest.mark.parametrize("n", [4, 5], ids=[f"n={n}" for n in [4, 5]])
@pytest.mark.parametrize(
"base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]]
)
def test_rng_different_at_each_index(space: Space, n: int, base_seed: int):
"""Tests that the rng values produced at each index are different to prevent if the rng is copied for each subspace."""
space.seed(base_seed)
batched_space = batch_space(space, n)
assert space.np_random is not batched_space.np_random
assert_rng_equal(space.np_random, batched_space.np_random)
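
    # A single batched sample should differ across the batch dimension,
    # i.e. the sub-spaces do not all share a copy of the same rng state.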
batched_sample = batched_space.sample()
sample = list(iterate(batched_space, batched_sample))
assert not all(np.all(element == sample[0]) for element in sample), sample
@pytest.mark.parametrize(
"space", spaces, ids=[space.__class__.__name__ for space in spaces]
)
@pytest.mark.parametrize("n", [1, 2, 5], ids=[f"n={n}" for n in [1, 2, 5]])
@pytest.mark.parametrize(
"base_seed", [123, 456], ids=[f"seed={base_seed}" for base_seed in [123, 456]]
)
def test_deterministic(space: Space, n: int, base_seed: int):
"""Tests the batched spaces are deterministic by using a copied version."""
# Copy the spaces and check that the np_random are not reference equal
space_a = space
space_a.seed(base_seed)
space_b = copy.deepcopy(space_a)
assert_rng_equal(space_a.np_random, space_b.np_random)
assert space_a.np_random is not space_b.np_random
    # Batch the spaces and check that their np_random generators are not reference equal
space_a_batched = batch_space(space_a, n)
space_b_batched = batch_space(space_b, n)
assert_rng_equal(space_a_batched.np_random, space_b_batched.np_random)
assert space_a_batched.np_random is not space_b_batched.np_random
    # Check that the batched space's rng is not reference equal to the original space's rng
assert space_a.np_random is not space_a_batched.np_random
    # Check that the random number generators of batched spaces a and b are not affected by sampling the original space
space_a.sample()
space_a_batched_sample = space_a_batched.sample()
space_b_batched_sample = space_b_batched.sample()
for a_sample, b_sample in zip(
iterate(space_a_batched, space_a_batched_sample),
iterate(space_b_batched, space_b_batched_sample),
):
if isinstance(a_sample, tuple):
assert len(a_sample) == len(b_sample)
for a_subsample, b_subsample in zip(a_sample, b_sample):
assert_array_equal(a_subsample, b_subsample)
else:
assert_array_equal(a_sample, b_sample)