mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-09-01 10:27:43 +00:00
Added singledispatch utility to vector.utils & changed order of space argument. (#2536)
* Fixed ordering of space. Added singledispatch utility. * Added singledispatch utility to vector.utils & changed order of space argument * Fixed Error from _BaseGymSpaces * Minor adjustment for Discrete Spaces * Fixed Tests/ to reflect changes * Fixed precommit error - custom namespaces * Concrete Implementations start with _
This commit is contained in:
@@ -9,6 +9,7 @@ _BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
|
||||
__all__ = ["_BaseGymSpaces", "batch_space", "iterate"]
|
||||
|
||||
|
||||
@singledispatch
|
||||
def batch_space(space, n=1):
|
||||
"""Create a (batched) space, containing multiple copies of a single space.
|
||||
|
||||
@@ -36,20 +37,15 @@ def batch_space(space, n=1):
|
||||
>>> batch_space(space, n=5)
|
||||
Dict(position:Box(5, 3), velocity:Box(5, 2))
|
||||
"""
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
return batch_space_base(space, n=n)
|
||||
elif isinstance(space, Tuple):
|
||||
return batch_space_tuple(space, n=n)
|
||||
elif isinstance(space, Dict):
|
||||
return batch_space_dict(space, n=n)
|
||||
elif isinstance(space, Space):
|
||||
return batch_space_custom(space, n=n)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gym.Space` instance."
|
||||
)
|
||||
raise ValueError(
|
||||
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gym.Space` instance."
|
||||
)
|
||||
|
||||
|
||||
@batch_space.register(Box)
|
||||
@batch_space.register(Discrete)
|
||||
@batch_space.register(MultiDiscrete)
|
||||
@batch_space.register(MultiBinary)
|
||||
def batch_space_base(space, n=1):
|
||||
if isinstance(space, Box):
|
||||
repeats = tuple([n] + [1] * space.low.ndim)
|
||||
@@ -71,10 +67,12 @@ def batch_space_base(space, n=1):
|
||||
raise ValueError(f"Space type `{type(space)}` is not supported.")
|
||||
|
||||
|
||||
@batch_space.register(Tuple)
|
||||
def batch_space_tuple(space, n=1):
|
||||
return Tuple(tuple(batch_space(subspace, n=n) for subspace in space.spaces))
|
||||
|
||||
|
||||
@batch_space.register(Dict)
|
||||
def batch_space_dict(space, n=1):
|
||||
return Dict(
|
||||
OrderedDict(
|
||||
@@ -86,6 +84,7 @@ def batch_space_dict(space, n=1):
|
||||
)
|
||||
|
||||
|
||||
@batch_space.register(Space)
|
||||
def batch_space_custom(space, n=1):
|
||||
return Tuple(tuple(space for _ in range(n)))
|
||||
|
||||
|
Reference in New Issue
Block a user