Files
Gymnasium/gym/vector/tests/test_shared_memory.py
Tristan Deleu c6a97e17ee Vectorized environments (#1513)
* Initial version of vectorized environments

* Raise an exception in the main process if child process raises an exception

* Add list of exposed functions in vector module

* Use deepcopy instead of np.copy

* Add documentation for vector utils

* Add tests for copy in AsyncVectorEnv

* Add example in documentation for batch_space

* Add cloudpickle dependency in setup.py

* Fix __del__ in VectorEnv

* Check if all observation spaces are equal in AsyncVectorEnv

* Check if all observation spaces are equal in SyncVectorEnv

* Fix spaces non equality in SyncVectorEnv for Python 2

* Handle None parameter in create_empty_array

* Fix check_observation_space with spaces equality

* Raise an exception when operations are out of order in AsyncVectorEnv

* Add version requirement for cloudpickle

* Use a state instead of binary flags in AsyncVectorEnv

* Use numpy.zeros when initializing observations in vectorized environments

* Remove poll from public API in AsyncVectorEnv

* Remove close_extras from VectorEnv

* Add test between AsyncVectorEnv and SyncVectorEnv

* Remove close in check_observation_space

* Add documentation for seed and close

* Refactor exceptions for AsyncVectorEnv

* Close pipes if the environment raises an error

* Add tests for out of order operations

* Change default argument in create_empty_array to np.zeros

* Add get_attr and set_attr methods to VectorEnv

* Improve consistency in SyncVectorEnv
2019-06-21 14:29:44 -07:00

138 lines
4.8 KiB
Python

import pytest
import numpy as np
from multiprocessing.sharedctypes import SynchronizedArray
from multiprocessing import Array, Process
from collections import OrderedDict
from gym.spaces import Tuple, Dict
from gym.vector.utils.spaces import _BaseGymSpaces
from gym.vector.tests.utils import spaces
from gym.vector.utils.shared_memory import (create_shared_memory,
read_from_shared_memory, write_to_shared_memory)
# Expected shared-memory structure for each entry of `spaces`, in the same
# order (the two sequences are zipped together when parametrizing
# `test_create_shared_memory`).  Array lengths are given for n=1; the test
# scales them by `n`.  Typecodes follow the `array` module: 'd' double,
# 'f' float, 'i' signed int, 'B' unsigned char.
# NOTE(review): tuples / OrderedDicts presumably correspond to Tuple / Dict
# spaces in `spaces` -- confirm against gym.vector.tests.utils.
expected_types = [
    Array('d', 1),
    Array('f', 1),
    Array('f', 3),
    Array('f', 4),
    Array('B', 1),
    Array('B', 32 * 32 * 3),
    Array('i', 1),
    (
        Array('i', 1),
        Array('i', 1),
    ),
    (
        Array('i', 1),
        Array('f', 2),
    ),
    Array('B', 3),
    Array('B', 19),
    OrderedDict([
        ('position', Array('i', 1)),
        ('velocity', Array('f', 1)),
    ]),
    OrderedDict([
        (
            'position',
            OrderedDict([
                ('x', Array('i', 1)),
                ('y', Array('i', 1)),
            ]),
        ),
        (
            'velocity',
            (
                Array('i', 1),
                Array('B', 1),
            ),
        ),
    ]),
]
@pytest.mark.parametrize('n', [1, 8])
@pytest.mark.parametrize('space,expected_type', list(zip(spaces, expected_types)),
                         ids=[space.__class__.__name__ for space in spaces])
def test_create_shared_memory(space, expected_type, n):
    """Check the structure returned by `create_shared_memory(space, n=n)`.

    The result must have the same nesting (sequences / mappings) as
    `expected_type`; every shared array leaf must be `n` times as long as the
    matching expected array and hold elements of the same Python type.
    """
    def check_structure(actual, expected):
        # Both sides must be exactly the same container / array type.
        assert type(actual) == type(expected)

        if isinstance(actual, (list, tuple)):
            assert len(actual) == len(expected)
            for actual_item, expected_item in zip(actual, expected):
                check_structure(actual_item, expected_item)
            return

        if isinstance(actual, (dict, OrderedDict)):
            # The two mappings must expose exactly the same keys.
            assert not set(actual.keys()).symmetric_difference(expected.keys())
            for name in actual.keys():
                check_structure(actual[name], expected[name])
            return

        if isinstance(actual, SynchronizedArray):
            actual_values = actual[:]
            expected_values = expected[:]
            # The expected array describes n=1, so the length scales with n.
            assert len(actual_values) == n * len(expected_values)
            # Same element type on both sides.
            assert type(actual[0]) == type(expected[0])
            return

        raise TypeError('Got unknown type `{0}`.'.format(type(actual)))

    created = create_shared_memory(space, n=n)
    check_structure(created, expected_type)
@pytest.mark.parametrize('space', spaces,
                         ids=[space.__class__.__name__ for space in spaces])
def test_write_to_shared_memory(space):
    """Check that `write_to_shared_memory` fills every slot from worker processes.

    Eight samples are drawn from `space`. Each one is written at its own index
    by a separate process, and the shared memory is then compared against the
    samples.
    """
    def assert_nested_equal(lhs, rhs):
        # `lhs` is the (possibly nested) shared memory, `rhs` the list of
        # samples; the list is re-sliced per element / key while recursing.
        assert isinstance(rhs, list)
        if isinstance(lhs, (list, tuple)):
            for i, lhs_i in enumerate(lhs):
                assert_nested_equal(lhs_i, [rhs_[i] for rhs_ in rhs])

        elif isinstance(lhs, (dict, OrderedDict)):
            for key in lhs.keys():
                assert_nested_equal(lhs[key], [rhs_[key] for rhs_ in rhs])

        elif isinstance(lhs, SynchronizedArray):
            # The shared array is flat, so compare it against the samples
            # stacked along a new first axis and flattened.
            assert np.all(np.array(lhs[:]) == np.stack(rhs, axis=0).flatten())

        else:
            raise TypeError('Got unknown type `{0}`.'.format(type(lhs)))

    # NOTE(review): `write` is a local closure over `space`; this presumably
    # relies on a process start method that does not pickle the target
    # (e.g. fork) -- confirm on platforms that default to spawn.
    def write(i, shared_memory, sample):
        write_to_shared_memory(i, sample, shared_memory, space)

    shared_memory_n8 = create_shared_memory(space, n=8)
    samples = [space.sample() for _ in range(8)]

    processes = [Process(target=write, args=(i, shared_memory_n8,
        samples[i])) for i in range(8)]

    for process in processes:
        process.start()
    for process in processes:
        process.join()

    # join() does not raise when a child fails, and a slot left untouched by
    # a crashed writer could coincidentally match its sample; require every
    # worker to have exited cleanly before comparing the contents.
    for process in processes:
        assert process.exitcode == 0

    assert_nested_equal(shared_memory_n8, samples)
@pytest.mark.parametrize('space', spaces,
                         ids=[space.__class__.__name__ for space in spaces])
def test_read_from_shared_memory(space):
    """Check that `read_from_shared_memory` reflects writes made by worker processes.

    The value returned by `read_from_shared_memory` is obtained *before* any
    sample is written. Eight processes then each write one sample, and the
    previously obtained value must match the samples.
    """
    def assert_nested_equal(lhs, rhs, space, n):
        # `lhs` is the value read from shared memory, `rhs` the list of
        # samples; recursion follows the structure of `space`.
        assert isinstance(rhs, list)
        if isinstance(space, Tuple):
            assert isinstance(lhs, tuple)
            for i, lhs_i in enumerate(lhs):
                assert_nested_equal(lhs_i, [rhs_[i] for rhs_ in rhs],
                    space.spaces[i], n)

        elif isinstance(space, Dict):
            assert isinstance(lhs, OrderedDict)
            for key in lhs.keys():
                assert_nested_equal(lhs[key], [rhs_[key] for rhs_ in rhs],
                    space.spaces[key], n)

        elif isinstance(space, _BaseGymSpaces):
            # Leaf: an array with a leading axis of length n, the space's
            # shape and dtype, equal to the samples stacked on axis 0.
            assert isinstance(lhs, np.ndarray)
            assert lhs.shape == ((n,) + space.shape)
            assert lhs.dtype == space.dtype
            assert np.all(lhs == np.stack(rhs, axis=0))

        else:
            raise TypeError('Got unknown type `{0}`'.format(type(space)))

    # NOTE(review): `write` is a local closure over `space`; this presumably
    # relies on a process start method that does not pickle the target
    # (e.g. fork) -- confirm on platforms that default to spawn.
    def write(i, shared_memory, sample):
        write_to_shared_memory(i, sample, shared_memory, space)

    shared_memory_n8 = create_shared_memory(space, n=8)
    # Read first, write afterwards: the final assertion only holds if this
    # value reflects data written to the shared memory later on.
    memory_view_n8 = read_from_shared_memory(shared_memory_n8, space, n=8)
    samples = [space.sample() for _ in range(8)]

    processes = [Process(target=write, args=(i, shared_memory_n8,
        samples[i])) for i in range(8)]

    for process in processes:
        process.start()
    for process in processes:
        process.join()

    # join() does not raise when a child fails; require every worker to have
    # exited cleanly so a crashed writer is reported as such.
    for process in processes:
        assert process.exitcode == 0

    assert_nested_equal(memory_view_n8, samples, space, n=8)