Added singledispatch utility to vector.utils & changed order of space argument. (#2536)

* Fixed ordering of space. Added singledispatch utility.

* Added singledispatch utility to vector.utils & changed order of space argument

* Fixed Error from _BaseGymSpaces

* Minor adjustment for Discrete Spaces

* Fixed Tests/ to reflect changes

* Fixed precommit error - custom namespaces

* Concrete Implementations start with _
This commit is contained in:
Rushiv Arora
2022-01-21 11:28:34 -05:00
committed by GitHub
parent 925823661d
commit fcbff7de12
8 changed files with 139 additions and 136 deletions

View File

@@ -9,6 +9,7 @@ _BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
__all__ = ["_BaseGymSpaces", "batch_space", "iterate"]
@singledispatch
def batch_space(space, n=1):
"""Create a (batched) space, containing multiple copies of a single space.
@@ -36,20 +37,15 @@ def batch_space(space, n=1):
>>> batch_space(space, n=5)
Dict(position:Box(5, 3), velocity:Box(5, 2))
"""
if isinstance(space, _BaseGymSpaces):
return batch_space_base(space, n=n)
elif isinstance(space, Tuple):
return batch_space_tuple(space, n=n)
elif isinstance(space, Dict):
return batch_space_dict(space, n=n)
elif isinstance(space, Space):
return batch_space_custom(space, n=n)
else:
raise ValueError(
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gym.Space` instance."
)
raise ValueError(
f"Cannot batch space with type `{type(space)}`. The space must be a valid `gym.Space` instance."
)
@batch_space.register(Box)
@batch_space.register(Discrete)
@batch_space.register(MultiDiscrete)
@batch_space.register(MultiBinary)
def batch_space_base(space, n=1):
if isinstance(space, Box):
repeats = tuple([n] + [1] * space.low.ndim)
@@ -71,10 +67,12 @@ def batch_space_base(space, n=1):
raise ValueError(f"Space type `{type(space)}` is not supported.")
@batch_space.register(Tuple)
def batch_space_tuple(space, n=1):
return Tuple(tuple(batch_space(subspace, n=n) for subspace in space.spaces))
@batch_space.register(Dict)
def batch_space_dict(space, n=1):
return Dict(
OrderedDict(
@@ -86,6 +84,7 @@ def batch_space_dict(space, n=1):
)
@batch_space.register(Space)
def batch_space_custom(space, n=1):
return Tuple(tuple(space for _ in range(n)))