mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 22:04:31 +00:00
* feat: add `isort` to `pre-commit` * ci: skip `__init__.py` file for `isort` * ci: make `isort` mandatory in lint pipeline * docs: add a section on Git hooks * ci: check isort diff * fix: isort from master branch * docs: add pre-commit badge * ci: update black + bandit versions * feat: add PR template * refactor: PR template * ci: remove bandit * docs: add Black badge * ci: try to remove all `|| true` statements * ci: remove lint_python job - Remove `lint_python` CI job - Move `pyupgrade` job to `pre-commit` workflow * fix: avoid messing with typing * docs: add a note on running `pre-cpmmit` manually * ci: apply `pre-commit` to the whole codebase
188 lines
5.3 KiB
Python
188 lines
5.3 KiB
Python
from collections import OrderedDict
|
|
from functools import singledispatch
|
|
|
|
import numpy as np
|
|
|
|
from gym.error import CustomSpaceError
|
|
from gym.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Space, Tuple
|
|
|
|
_BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
|
|
__all__ = ["_BaseGymSpaces", "batch_space", "iterate"]
|
|
|
|
|
|
@singledispatch
|
|
def batch_space(space, n=1):
|
|
"""Create a (batched) space, containing multiple copies of a single space.
|
|
|
|
Parameters
|
|
----------
|
|
space : `gym.spaces.Space` instance
|
|
Space (e.g. the observation space) for a single environment in the
|
|
vectorized environment.
|
|
|
|
n : int
|
|
Number of environments in the vectorized environment.
|
|
|
|
Returns
|
|
-------
|
|
batched_space : `gym.spaces.Space` instance
|
|
Space (e.g. the observation space) for a batch of environments in the
|
|
vectorized environment.
|
|
|
|
Example
|
|
-------
|
|
>>> from gym.spaces import Box, Dict
|
|
>>> space = Dict({
|
|
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
|
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
|
|
>>> batch_space(space, n=5)
|
|
Dict(position:Box(5, 3), velocity:Box(5, 2))
|
|
"""
|
|
raise ValueError(
|
|
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gym.Space` instance."
|
|
)
|
|
|
|
|
|
@batch_space.register(Box)
|
|
def _batch_space_box(space, n=1):
|
|
repeats = tuple([n] + [1] * space.low.ndim)
|
|
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
|
|
return Box(low=low, high=high, dtype=space.dtype)
|
|
|
|
|
|
@batch_space.register(Discrete)
|
|
def _batch_space_discrete(space, n=1):
|
|
if space.start == 0:
|
|
return MultiDiscrete(np.full((n,), space.n, dtype=space.dtype))
|
|
else:
|
|
return Box(
|
|
low=space.start,
|
|
high=space.start + space.n - 1,
|
|
shape=(n,),
|
|
dtype=space.dtype,
|
|
)
|
|
|
|
|
|
@batch_space.register(MultiDiscrete)
|
|
def _batch_space_multidiscrete(space, n=1):
|
|
repeats = tuple([n] + [1] * space.nvec.ndim)
|
|
high = np.tile(space.nvec, repeats) - 1
|
|
return Box(low=np.zeros_like(high), high=high, dtype=space.dtype)
|
|
|
|
|
|
@batch_space.register(MultiBinary)
|
|
def _batch_space_multibinary(space, n=1):
|
|
return Box(low=0, high=1, shape=(n,) + space.shape, dtype=space.dtype)
|
|
|
|
|
|
@batch_space.register(Tuple)
|
|
def _batch_space_tuple(space, n=1):
|
|
return Tuple(tuple(batch_space(subspace, n=n) for subspace in space.spaces))
|
|
|
|
|
|
@batch_space.register(Dict)
|
|
def _batch_space_dict(space, n=1):
|
|
return Dict(
|
|
OrderedDict(
|
|
[
|
|
(key, batch_space(subspace, n=n))
|
|
for (key, subspace) in space.spaces.items()
|
|
]
|
|
)
|
|
)
|
|
|
|
|
|
@batch_space.register(Space)
|
|
def _batch_space_custom(space, n=1):
|
|
return Tuple(tuple(space for _ in range(n)))
|
|
|
|
|
|
@singledispatch
|
|
def iterate(space, items):
|
|
"""Iterate over the elements of a (batched) space.
|
|
|
|
Parameters
|
|
----------
|
|
space : `gym.spaces.Space` instance
|
|
Space to which `items` belong to.
|
|
|
|
items : samples of `space`
|
|
Items to be iterated over.
|
|
|
|
Returns
|
|
-------
|
|
iterator : `Iterable` instance
|
|
Iterator over the elements in `items`.
|
|
|
|
Example
|
|
-------
|
|
>>> from gym.spaces import Box, Dict
|
|
>>> space = Dict({
|
|
... 'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32),
|
|
... 'velocity': Box(low=0, high=1, shape=(2, 2), dtype=np.float32)})
|
|
>>> items = space.sample()
|
|
>>> it = iterate(space, items)
|
|
>>> next(it)
|
|
{'position': array([-0.99644893, -0.08304597, -0.7238421 ], dtype=float32),
|
|
'velocity': array([0.35848552, 0.1533453 ], dtype=float32)}
|
|
>>> next(it)
|
|
{'position': array([-0.67958736, -0.49076623, 0.38661423], dtype=float32),
|
|
'velocity': array([0.7975036 , 0.93317133], dtype=float32)}
|
|
>>> next(it)
|
|
StopIteration
|
|
"""
|
|
raise ValueError(
|
|
"Space of type `{}` is not a valid `gym.Space` " "instance.".format(type(space))
|
|
)
|
|
|
|
|
|
@iterate.register(Discrete)
|
|
def _iterate_discrete(space, items):
|
|
raise TypeError("Unable to iterate over a space of type `Discrete`.")
|
|
|
|
|
|
@iterate.register(Box)
|
|
@iterate.register(MultiDiscrete)
|
|
@iterate.register(MultiBinary)
|
|
def _iterate_base(space, items):
|
|
try:
|
|
return iter(items)
|
|
except TypeError:
|
|
raise TypeError(f"Unable to iterate over the following elements: {items}")
|
|
|
|
|
|
@iterate.register(Tuple)
|
|
def _iterate_tuple(space, items):
|
|
# If this is a tuple of custom subspaces only, then simply iterate over items
|
|
if all(
|
|
isinstance(subspace, Space)
|
|
and (not isinstance(subspace, _BaseGymSpaces + (Tuple, Dict)))
|
|
for subspace in space.spaces
|
|
):
|
|
return iter(items)
|
|
|
|
return zip(
|
|
*[iterate(subspace, items[i]) for i, subspace in enumerate(space.spaces)]
|
|
)
|
|
|
|
|
|
@iterate.register(Dict)
|
|
def _iterate_dict(space, items):
|
|
keys, values = zip(
|
|
*[
|
|
(key, iterate(subspace, items[key]))
|
|
for key, subspace in space.spaces.items()
|
|
]
|
|
)
|
|
for item in zip(*values):
|
|
yield OrderedDict([(key, value) for (key, value) in zip(keys, item)])
|
|
|
|
|
|
@iterate.register(Space)
|
|
def _iterate_custom(space, items):
|
|
raise CustomSpaceError(
|
|
f"Unable to iterate over {items}, since {space} "
|
|
"is a custom `gym.Space` instance (i.e. not one of "
|
|
"`Box`, `Dict`, etc...)."
|
|
)
|