mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-20 05:52:03 +00:00
Fix reading shared memory for Tuple
and Dict
spaces (#941)
This commit is contained in:
@@ -148,20 +148,25 @@ def _read_base_from_shared_memory(
|
|||||||
|
|
||||||
@read_from_shared_memory.register(Tuple)
|
@read_from_shared_memory.register(Tuple)
|
||||||
def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
|
def _read_tuple_from_shared_memory(space: Tuple, shared_memory, n: int = 1):
|
||||||
return tuple(
|
subspace_samples = tuple(
|
||||||
read_from_shared_memory(subspace, memory, n=n)
|
read_from_shared_memory(subspace, memory, n=n)
|
||||||
for (memory, subspace) in zip(shared_memory, space.spaces)
|
for (memory, subspace) in zip(shared_memory, space.spaces)
|
||||||
)
|
)
|
||||||
|
return tuple(zip(*subspace_samples))
|
||||||
|
|
||||||
|
|
||||||
@read_from_shared_memory.register(Dict)
|
@read_from_shared_memory.register(Dict)
|
||||||
def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
|
def _read_dict_from_shared_memory(space: Dict, shared_memory, n: int = 1):
|
||||||
return OrderedDict(
|
subspace_samples = OrderedDict(
|
||||||
[
|
[
|
||||||
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
|
(key, read_from_shared_memory(subspace, shared_memory[key], n=n))
|
||||||
for (key, subspace) in space.spaces.items()
|
for (key, subspace) in space.spaces.items()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
return tuple(
|
||||||
|
OrderedDict({key: subspace_samples[key][i] for key in space.keys()})
|
||||||
|
for i in range(n)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@read_from_shared_memory.register(Text)
|
@read_from_shared_memory.register(Text)
|
||||||
|
@@ -22,7 +22,7 @@ from tests.spaces.utils import TESTING_SPACES, TESTING_SPACES_IDS
|
|||||||
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
|
"ctx", [None, "fork", "spawn"], ids=["default", "fork", "spawn"]
|
||||||
)
|
)
|
||||||
def test_shared_memory_create_read_write(space, num, ctx):
|
def test_shared_memory_create_read_write(space, num, ctx):
|
||||||
"""Test the shared memory functions, create, read and write for all of the testing spaces."""
|
"""Test the shared memory functions, create, read and write for all testing spaces."""
|
||||||
if ctx not in mp.get_all_start_methods():
|
if ctx not in mp.get_all_start_methods():
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
f"Multiprocessing start method {ctx} not available on this platform."
|
f"Multiprocessing start method {ctx} not available on this platform."
|
||||||
@@ -41,7 +41,7 @@ def test_shared_memory_create_read_write(space, num, ctx):
|
|||||||
|
|
||||||
read_samples = read_from_shared_memory(space, shared_memory, n=num)
|
read_samples = read_from_shared_memory(space, shared_memory, n=num)
|
||||||
for read_sample, sample in zip(read_samples, samples):
|
for read_sample, sample in zip(read_samples, samples):
|
||||||
data_equivalence(read_sample, sample)
|
assert data_equivalence(read_sample, sample)
|
||||||
|
|
||||||
|
|
||||||
def test_custom_space():
|
def test_custom_space():
|
||||||
|
Reference in New Issue
Block a user