Added singledispatch utility to vector.utils & changed order of space argument. (#2536)

* Fixed ordering of space. Added singledispatch utility.

* Added singledispatch utility to vector.utils & changed order of space argument

* Fixed Error from _BaseGymSpaces

* Minor adjustment for Discrete Spaces

* Fixed Tests/ to reflect changes

* Fixed precommit error - custom namespaces

* Concrete Implementations start with _
This commit is contained in:
Rushiv Arora
2022-01-21 11:28:34 -05:00
committed by GitHub
parent 925823661d
commit fcbff7de12
8 changed files with 139 additions and 136 deletions

View File

@@ -4,13 +4,16 @@ from ctypes import c_bool
from collections import OrderedDict
from gym import logger
from gym.spaces import Tuple, Dict
from gym.spaces import Space, Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
from gym.error import CustomSpaceError
from gym.vector.utils.spaces import _BaseGymSpaces
from functools import singledispatch
__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]
@singledispatch
def create_shared_memory(space, n=1, ctx=mp):
"""Create a shared memory object, to be shared across processes. This
eventually contains the observations from the vectorized environment.
@@ -32,36 +35,35 @@ def create_shared_memory(space, n=1, ctx=mp):
shared_memory : dict, tuple, or `multiprocessing.Array` instance
Shared object across processes.
"""
if isinstance(space, _BaseGymSpaces):
return create_base_shared_memory(space, n=n, ctx=ctx)
elif isinstance(space, Tuple):
return create_tuple_shared_memory(space, n=n, ctx=ctx)
elif isinstance(space, Dict):
return create_dict_shared_memory(space, n=n, ctx=ctx)
else:
raise CustomSpaceError(
"Cannot create a shared memory for space with "
"type `{}`. Shared memory only supports "
"default Gym spaces (e.g. `Box`, `Tuple`, "
"`Dict`, etc...), and does not support custom "
"Gym spaces.".format(type(space))
)
raise CustomSpaceError(
"Cannot create a shared memory for space with "
"type `{}`. Shared memory only supports "
"default Gym spaces (e.g. `Box`, `Tuple`, "
"`Dict`, etc...), and does not support custom "
"Gym spaces.".format(type(space))
)
def create_base_shared_memory(space, n=1, ctx=mp):
@create_shared_memory.register(Box)
@create_shared_memory.register(Discrete)
@create_shared_memory.register(MultiDiscrete)
@create_shared_memory.register(MultiBinary)
def _create_base_shared_memory(space, n=1, ctx=mp):
dtype = space.dtype.char
if dtype in "?":
dtype = c_bool
return ctx.Array(dtype, n * int(np.prod(space.shape)))
def create_tuple_shared_memory(space, n=1, ctx=mp):
@create_shared_memory.register(Tuple)
def _create_tuple_shared_memory(space, n=1, ctx=mp):
return tuple(
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
)
def create_dict_shared_memory(space, n=1, ctx=mp):
@create_shared_memory.register(Dict)
def _create_dict_shared_memory(space, n=1, ctx=mp):
return OrderedDict(
[
(key, create_shared_memory(subspace, n=n, ctx=ctx))
@@ -70,7 +72,8 @@ def create_dict_shared_memory(space, n=1, ctx=mp):
)
def read_from_shared_memory(shared_memory, space, n=1):
@singledispatch
def read_from_shared_memory(space, shared_memory, n=1):
"""Read the batch of observations from shared memory as a numpy array.
Parameters
@@ -97,45 +100,45 @@ def read_from_shared_memory(shared_memory, space, n=1):
memory of `shared_memory`. Any changes to `shared_memory` are forwarded
to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`.
"""
if isinstance(space, _BaseGymSpaces):
return read_base_from_shared_memory(shared_memory, space, n=n)
elif isinstance(space, Tuple):
return read_tuple_from_shared_memory(shared_memory, space, n=n)
elif isinstance(space, Dict):
return read_dict_from_shared_memory(shared_memory, space, n=n)
else:
raise CustomSpaceError(
"Cannot read from a shared memory for space with "
"type `{}`. Shared memory only supports "
"default Gym spaces (e.g. `Box`, `Tuple`, "
"`Dict`, etc...), and does not support custom "
"Gym spaces.".format(type(space))
)
raise CustomSpaceError(
"Cannot read from a shared memory for space with "
"type `{}`. Shared memory only supports "
"default Gym spaces (e.g. `Box`, `Tuple`, "
"`Dict`, etc...), and does not support custom "
"Gym spaces.".format(type(space))
)
def read_base_from_shared_memory(shared_memory, space, n=1):
@read_from_shared_memory.register(Box)
@read_from_shared_memory.register(Discrete)
@read_from_shared_memory.register(MultiDiscrete)
@read_from_shared_memory.register(MultiBinary)
def _read_base_from_shared_memory(space, shared_memory, n=1):
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
(n,) + space.shape
)
def read_tuple_from_shared_memory(shared_memory, space, n=1):
@read_from_shared_memory.register(Tuple)
def _read_tuple_from_shared_memory(space, shared_memory, n=1):
return tuple(
read_from_shared_memory(memory, subspace, n=n)
read_from_shared_memory(subspace, memory, n=n)
for (memory, subspace) in zip(shared_memory, space.spaces)
)
def read_dict_from_shared_memory(shared_memory, space, n=1):
@read_from_shared_memory.register(Dict)
def _read_dict_from_shared_memory(space, shared_memory, n=1):
return OrderedDict(
[
(key, read_from_shared_memory(shared_memory[key], subspace, n=n))
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
for (key, subspace) in space.spaces.items()
]
)
def write_to_shared_memory(index, value, shared_memory, space):
@singledispatch
def write_to_shared_memory(space, index, value, shared_memory):
"""Write the observation of a single environment into shared memory.
Parameters
@@ -157,23 +160,20 @@ def write_to_shared_memory(index, value, shared_memory, space):
-------
`None`
"""
if isinstance(space, _BaseGymSpaces):
write_base_to_shared_memory(index, value, shared_memory, space)
elif isinstance(space, Tuple):
write_tuple_to_shared_memory(index, value, shared_memory, space)
elif isinstance(space, Dict):
write_dict_to_shared_memory(index, value, shared_memory, space)
else:
raise CustomSpaceError(
"Cannot write to a shared memory for space with "
"type `{}`. Shared memory only supports "
"default Gym spaces (e.g. `Box`, `Tuple`, "
"`Dict`, etc...), and does not support custom "
"Gym spaces.".format(type(space))
)
raise CustomSpaceError(
"Cannot write to a shared memory for space with "
"type `{}`. Shared memory only supports "
"default Gym spaces (e.g. `Box`, `Tuple`, "
"`Dict`, etc...), and does not support custom "
"Gym spaces.".format(type(space))
)
def write_base_to_shared_memory(index, value, shared_memory, space):
@write_to_shared_memory.register(Box)
@write_to_shared_memory.register(Discrete)
@write_to_shared_memory.register(MultiDiscrete)
@write_to_shared_memory.register(MultiBinary)
def _write_base_to_shared_memory(space, index, value, shared_memory):
size = int(np.prod(space.shape))
destination = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
np.copyto(
@@ -182,11 +182,13 @@ def write_base_to_shared_memory(index, value, shared_memory, space):
)
def write_tuple_to_shared_memory(index, values, shared_memory, space):
@write_to_shared_memory.register(Tuple)
def _write_tuple_to_shared_memory(space, index, values, shared_memory):
for value, memory, subspace in zip(values, shared_memory, space.spaces):
write_to_shared_memory(index, value, memory, subspace)
write_to_shared_memory(subspace, index, value, memory)
def write_dict_to_shared_memory(index, values, shared_memory, space):
@write_to_shared_memory.register(Dict)
def _write_dict_to_shared_memory(space, index, values, shared_memory):
for key, subspace in space.spaces.items():
write_to_shared_memory(index, values[key], shared_memory[key], subspace)
write_to_shared_memory(subspace, index, values[key], shared_memory[key])