Added singledispatch utility to vector.utils & changed order of space argument. (#2536)

* Fixed ordering of space. Added singledispatch utility.

* Added singledispatch utility to vector.utils & changed order of space argument

* Fixed Error from _BaseGymSpaces

* Minor adjustment for Discrete Spaces

* Fixed Tests/ to reflect changes

* Fixed precommit error - custom namespaces

* Concrete Implementations start with _
This commit is contained in:
Rushiv Arora
2022-01-21 11:28:34 -05:00
committed by GitHub
parent 925823661d
commit fcbff7de12
8 changed files with 139 additions and 136 deletions

View File

@@ -1,13 +1,16 @@
import numpy as np
from gym.spaces import Space, Tuple, Dict
from gym.spaces import Space, Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
from gym.vector.utils.spaces import _BaseGymSpaces
from collections import OrderedDict
from functools import singledispatch
__all__ = ["concatenate", "create_empty_array"]
def concatenate(items, out, space):
@singledispatch
def concatenate(space, items, out):
"""Concatenate multiple samples from space into a single object.
Parameters
@@ -37,44 +40,43 @@ def concatenate(items, out, space):
[0.87383074, 0.192658 , 0.2148103 ]], dtype=float32)
"""
assert isinstance(items, (list, tuple))
if isinstance(space, _BaseGymSpaces):
return concatenate_base(items, out, space)
elif isinstance(space, Tuple):
return concatenate_tuple(items, out, space)
elif isinstance(space, Dict):
return concatenate_dict(items, out, space)
elif isinstance(space, Space):
return concatenate_custom(items, out, space)
else:
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
)
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
)
def concatenate_base(items, out, space):
@concatenate.register(Box)
@concatenate.register(Discrete)
@concatenate.register(MultiDiscrete)
@concatenate.register(MultiBinary)
def _concatenate_base(space, items, out):
return np.stack(items, axis=0, out=out)
def concatenate_tuple(items, out, space):
@concatenate.register(Tuple)
def _concatenate_tuple(space, items, out):
return tuple(
concatenate([item[i] for item in items], out[i], subspace)
concatenate(subspace, [item[i] for item in items], out[i])
for (i, subspace) in enumerate(space.spaces)
)
def concatenate_dict(items, out, space):
@concatenate.register(Dict)
def _concatenate_dict(space, items, out):
return OrderedDict(
[
(key, concatenate([item[key] for item in items], out[key], subspace))
(key, concatenate(subspace, [item[key] for item in items], out[key]))
for (key, subspace) in space.spaces.items()
]
)
def concatenate_custom(items, out, space):
@concatenate.register(Space)
def _concatenate_custom(space, items, out):
return tuple(items)
@singledispatch
def create_empty_array(space, n=1, fn=np.zeros):
"""Create an empty (possibly nested) numpy array.
@@ -108,30 +110,27 @@ def create_empty_array(space, n=1, fn=np.zeros):
('velocity', array([[0., 0.],
[0., 0.]], dtype=float32))])
"""
if isinstance(space, _BaseGymSpaces):
return create_empty_array_base(space, n=n, fn=fn)
elif isinstance(space, Tuple):
return create_empty_array_tuple(space, n=n, fn=fn)
elif isinstance(space, Dict):
return create_empty_array_dict(space, n=n, fn=fn)
elif isinstance(space, Space):
return create_empty_array_custom(space, n=n, fn=fn)
else:
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
)
raise ValueError(
f"Space of type `{type(space)}` is not a valid `gym.Space` instance."
)
def create_empty_array_base(space, n=1, fn=np.zeros):
@create_empty_array.register(Box)
@create_empty_array.register(Discrete)
@create_empty_array.register(MultiDiscrete)
@create_empty_array.register(MultiBinary)
def _create_empty_array_base(space, n=1, fn=np.zeros):
shape = space.shape if (n is None) else (n,) + space.shape
return fn(shape, dtype=space.dtype)
def create_empty_array_tuple(space, n=1, fn=np.zeros):
@create_empty_array.register(Tuple)
def _create_empty_array_tuple(space, n=1, fn=np.zeros):
return tuple(create_empty_array(subspace, n=n, fn=fn) for subspace in space.spaces)
def create_empty_array_dict(space, n=1, fn=np.zeros):
@create_empty_array.register(Dict)
def _create_empty_array_dict(space, n=1, fn=np.zeros):
return OrderedDict(
[
(key, create_empty_array(subspace, n=n, fn=fn))
@@ -140,5 +139,6 @@ def create_empty_array_dict(space, n=1, fn=np.zeros):
)
def create_empty_array_custom(space, n=1, fn=np.zeros):
@create_empty_array.register(Space)
def _create_empty_array_custom(space, n=1, fn=np.zeros):
return None