mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-18 21:06:59 +00:00
Added singledispatch utility to vector.utils & changed order of space argument. (#2536)
* Fixed ordering of space. Added singledispatch utility. * Added singledispatch utility to vector.utils & changed order of space argument * Fixed Error from _BaseGymSpaces * Minor adjustment for Discrete Spaces * Fixed Tests/ to reflect changes * Fixed precommit error - custom namespaces * Concrete Implementations start with _
This commit is contained in:
@@ -4,13 +4,16 @@ from ctypes import c_bool
|
||||
from collections import OrderedDict
|
||||
|
||||
from gym import logger
|
||||
from gym.spaces import Tuple, Dict
|
||||
from gym.spaces import Space, Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
|
||||
from gym.error import CustomSpaceError
|
||||
from gym.vector.utils.spaces import _BaseGymSpaces
|
||||
|
||||
from functools import singledispatch
|
||||
|
||||
__all__ = ["create_shared_memory", "read_from_shared_memory", "write_to_shared_memory"]
|
||||
|
||||
|
||||
@singledispatch
|
||||
def create_shared_memory(space, n=1, ctx=mp):
|
||||
"""Create a shared memory object, to be shared across processes. This
|
||||
eventually contains the observations from the vectorized environment.
|
||||
@@ -32,36 +35,35 @@ def create_shared_memory(space, n=1, ctx=mp):
|
||||
shared_memory : dict, tuple, or `multiprocessing.Array` instance
|
||||
Shared object across processes.
|
||||
"""
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
return create_base_shared_memory(space, n=n, ctx=ctx)
|
||||
elif isinstance(space, Tuple):
|
||||
return create_tuple_shared_memory(space, n=n, ctx=ctx)
|
||||
elif isinstance(space, Dict):
|
||||
return create_dict_shared_memory(space, n=n, ctx=ctx)
|
||||
else:
|
||||
raise CustomSpaceError(
|
||||
"Cannot create a shared memory for space with "
|
||||
"type `{}`. Shared memory only supports "
|
||||
"default Gym spaces (e.g. `Box`, `Tuple`, "
|
||||
"`Dict`, etc...), and does not support custom "
|
||||
"Gym spaces.".format(type(space))
|
||||
)
|
||||
raise CustomSpaceError(
|
||||
"Cannot create a shared memory for space with "
|
||||
"type `{}`. Shared memory only supports "
|
||||
"default Gym spaces (e.g. `Box`, `Tuple`, "
|
||||
"`Dict`, etc...), and does not support custom "
|
||||
"Gym spaces.".format(type(space))
|
||||
)
|
||||
|
||||
|
||||
def create_base_shared_memory(space, n=1, ctx=mp):
|
||||
@create_shared_memory.register(Box)
|
||||
@create_shared_memory.register(Discrete)
|
||||
@create_shared_memory.register(MultiDiscrete)
|
||||
@create_shared_memory.register(MultiBinary)
|
||||
def _create_base_shared_memory(space, n=1, ctx=mp):
|
||||
dtype = space.dtype.char
|
||||
if dtype in "?":
|
||||
dtype = c_bool
|
||||
return ctx.Array(dtype, n * int(np.prod(space.shape)))
|
||||
|
||||
|
||||
def create_tuple_shared_memory(space, n=1, ctx=mp):
|
||||
@create_shared_memory.register(Tuple)
|
||||
def _create_tuple_shared_memory(space, n=1, ctx=mp):
|
||||
return tuple(
|
||||
create_shared_memory(subspace, n=n, ctx=ctx) for subspace in space.spaces
|
||||
)
|
||||
|
||||
|
||||
def create_dict_shared_memory(space, n=1, ctx=mp):
|
||||
@create_shared_memory.register(Dict)
|
||||
def _create_dict_shared_memory(space, n=1, ctx=mp):
|
||||
return OrderedDict(
|
||||
[
|
||||
(key, create_shared_memory(subspace, n=n, ctx=ctx))
|
||||
@@ -70,7 +72,8 @@ def create_dict_shared_memory(space, n=1, ctx=mp):
|
||||
)
|
||||
|
||||
|
||||
def read_from_shared_memory(shared_memory, space, n=1):
|
||||
@singledispatch
|
||||
def read_from_shared_memory(space, shared_memory, n=1):
|
||||
"""Read the batch of observations from shared memory as a numpy array.
|
||||
|
||||
Parameters
|
||||
@@ -97,45 +100,45 @@ def read_from_shared_memory(shared_memory, space, n=1):
|
||||
memory of `shared_memory`. Any changes to `shared_memory` are forwarded
|
||||
to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`.
|
||||
"""
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
return read_base_from_shared_memory(shared_memory, space, n=n)
|
||||
elif isinstance(space, Tuple):
|
||||
return read_tuple_from_shared_memory(shared_memory, space, n=n)
|
||||
elif isinstance(space, Dict):
|
||||
return read_dict_from_shared_memory(shared_memory, space, n=n)
|
||||
else:
|
||||
raise CustomSpaceError(
|
||||
"Cannot read from a shared memory for space with "
|
||||
"type `{}`. Shared memory only supports "
|
||||
"default Gym spaces (e.g. `Box`, `Tuple`, "
|
||||
"`Dict`, etc...), and does not support custom "
|
||||
"Gym spaces.".format(type(space))
|
||||
)
|
||||
raise CustomSpaceError(
|
||||
"Cannot read from a shared memory for space with "
|
||||
"type `{}`. Shared memory only supports "
|
||||
"default Gym spaces (e.g. `Box`, `Tuple`, "
|
||||
"`Dict`, etc...), and does not support custom "
|
||||
"Gym spaces.".format(type(space))
|
||||
)
|
||||
|
||||
|
||||
def read_base_from_shared_memory(shared_memory, space, n=1):
|
||||
@read_from_shared_memory.register(Box)
|
||||
@read_from_shared_memory.register(Discrete)
|
||||
@read_from_shared_memory.register(MultiDiscrete)
|
||||
@read_from_shared_memory.register(MultiBinary)
|
||||
def _read_base_from_shared_memory(space, shared_memory, n=1):
|
||||
return np.frombuffer(shared_memory.get_obj(), dtype=space.dtype).reshape(
|
||||
(n,) + space.shape
|
||||
)
|
||||
|
||||
|
||||
def read_tuple_from_shared_memory(shared_memory, space, n=1):
|
||||
@read_from_shared_memory.register(Tuple)
|
||||
def _read_tuple_from_shared_memory(space, shared_memory, n=1):
|
||||
return tuple(
|
||||
read_from_shared_memory(memory, subspace, n=n)
|
||||
read_from_shared_memory(subspace, memory, n=n)
|
||||
for (memory, subspace) in zip(shared_memory, space.spaces)
|
||||
)
|
||||
|
||||
|
||||
def read_dict_from_shared_memory(shared_memory, space, n=1):
|
||||
@read_from_shared_memory.register(Dict)
|
||||
def _read_dict_from_shared_memory(space, shared_memory, n=1):
|
||||
return OrderedDict(
|
||||
[
|
||||
(key, read_from_shared_memory(shared_memory[key], subspace, n=n))
|
||||
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
|
||||
for (key, subspace) in space.spaces.items()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def write_to_shared_memory(index, value, shared_memory, space):
|
||||
@singledispatch
|
||||
def write_to_shared_memory(space, index, value, shared_memory):
|
||||
"""Write the observation of a single environment into shared memory.
|
||||
|
||||
Parameters
|
||||
@@ -157,23 +160,20 @@ def write_to_shared_memory(index, value, shared_memory, space):
|
||||
-------
|
||||
`None`
|
||||
"""
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
write_base_to_shared_memory(index, value, shared_memory, space)
|
||||
elif isinstance(space, Tuple):
|
||||
write_tuple_to_shared_memory(index, value, shared_memory, space)
|
||||
elif isinstance(space, Dict):
|
||||
write_dict_to_shared_memory(index, value, shared_memory, space)
|
||||
else:
|
||||
raise CustomSpaceError(
|
||||
"Cannot write to a shared memory for space with "
|
||||
"type `{}`. Shared memory only supports "
|
||||
"default Gym spaces (e.g. `Box`, `Tuple`, "
|
||||
"`Dict`, etc...), and does not support custom "
|
||||
"Gym spaces.".format(type(space))
|
||||
)
|
||||
raise CustomSpaceError(
|
||||
"Cannot write to a shared memory for space with "
|
||||
"type `{}`. Shared memory only supports "
|
||||
"default Gym spaces (e.g. `Box`, `Tuple`, "
|
||||
"`Dict`, etc...), and does not support custom "
|
||||
"Gym spaces.".format(type(space))
|
||||
)
|
||||
|
||||
|
||||
def write_base_to_shared_memory(index, value, shared_memory, space):
|
||||
@write_to_shared_memory.register(Box)
|
||||
@write_to_shared_memory.register(Discrete)
|
||||
@write_to_shared_memory.register(MultiDiscrete)
|
||||
@write_to_shared_memory.register(MultiBinary)
|
||||
def _write_base_to_shared_memory(space, index, value, shared_memory):
|
||||
size = int(np.prod(space.shape))
|
||||
destination = np.frombuffer(shared_memory.get_obj(), dtype=space.dtype)
|
||||
np.copyto(
|
||||
@@ -182,11 +182,13 @@ def write_base_to_shared_memory(index, value, shared_memory, space):
|
||||
)
|
||||
|
||||
|
||||
def write_tuple_to_shared_memory(index, values, shared_memory, space):
|
||||
@write_to_shared_memory.register(Tuple)
|
||||
def _write_tuple_to_shared_memory(space, index, values, shared_memory):
|
||||
for value, memory, subspace in zip(values, shared_memory, space.spaces):
|
||||
write_to_shared_memory(index, value, memory, subspace)
|
||||
write_to_shared_memory(subspace, index, value, memory)
|
||||
|
||||
|
||||
def write_dict_to_shared_memory(index, values, shared_memory, space):
|
||||
@write_to_shared_memory.register(Dict)
|
||||
def _write_dict_to_shared_memory(space, index, values, shared_memory):
|
||||
for key, subspace in space.spaces.items():
|
||||
write_to_shared_memory(index, values[key], shared_memory[key], subspace)
|
||||
write_to_shared_memory(subspace, index, values[key], shared_memory[key])
|
||||
|
Reference in New Issue
Block a user