mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 06:07:08 +00:00
Vectorized environments (#1513)
* Initial version of vectorized environments * Raise an exception in the main process if child process raises an exception * Add list of exposed functions in vector module * Use deepcopy instead of np.copy * Add documentation for vector utils * Add tests for copy in AsyncVectorEnv * Add example in documentation for batch_space * Add cloudpickle dependency in setup.py * Fix __del__ in VectorEnv * Check if all observation spaces are equal in AsyncVectorEnv * Check if all observation spaces are equal in SyncVectorEnv * Fix spaces non equality in SyncVectorEnv for Python 2 * Handle None parameter in create_empty_array * Fix check_observation_space with spaces equality * Raise an exception when operations are out of order in AsyncVectorEnv * Add version requirement for cloudpickle * Use a state instead of binary flags in AsyncVectorEnv * Use numpy.zeros when initializing observations in vectorized environments * Remove poll from public API in AsyncVectorEnv * Remove close_extras from VectorEnv * Add test between AsyncVectorEnv and SyncVectorEnv * Remove close in check_observation_space * Add documentation for seed and close * Refactor exceptions for AsyncVectorEnv * Close pipes if the environment raises an error * Add tests for out of order operations * Change default argument in create_empty_array to np.zeros * Add get_attr and set_attr methods to VectorEnv * Improve consistency in SyncVectorEnv
This commit is contained in:
@@ -10,5 +10,6 @@ from gym.core import Env, GoalEnv, Wrapper, ObservationWrapper, ActionWrapper, R
|
||||
from gym.spaces import Space
|
||||
from gym.envs import make, spec, register
|
||||
from gym import logger
|
||||
from gym import vector
|
||||
|
||||
__all__ = ["Env", "Space", "Wrapper", "make", "spec", "register"]
|
||||
|
28
gym/error.py
28
gym/error.py
@@ -137,3 +137,31 @@ class WrapAfterConfigureError(Error):
|
||||
|
||||
class RetriesExceededError(Error):
|
||||
pass
|
||||
|
||||
# Vectorized environments errors
|
||||
|
||||
class AlreadyPendingCallError(Exception):
|
||||
"""
|
||||
Raised when `reset`, or `step` is called asynchronously (e.g. with
|
||||
`reset_async`, or `step_async` respectively), and `reset_async`, or
|
||||
`step_async` (respectively) is called again (without a complete call to
|
||||
`reset_wait`, or `step_wait` respectively).
|
||||
"""
|
||||
def __init__(self, message, name):
|
||||
super(AlreadyPendingCallError, self).__init__(message)
|
||||
self.name = name
|
||||
|
||||
class NoAsyncCallError(Exception):
|
||||
"""
|
||||
Raised when an asynchronous `reset`, or `step` is not running, but
|
||||
`reset_wait`, or `step_wait` (respectively) is called.
|
||||
"""
|
||||
def __init__(self, message, name):
|
||||
super(NoAsyncCallError, self).__init__(message)
|
||||
self.name = name
|
||||
|
||||
class ClosedEnvironmentError(Exception):
|
||||
"""
|
||||
Trying to call `reset`, or `step`, while the environment is closed.
|
||||
"""
|
||||
pass
|
||||
|
47
gym/vector/__init__.py
Normal file
47
gym/vector/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from gym.vector.async_vector_env import AsyncVectorEnv
|
||||
from gym.vector.sync_vector_env import SyncVectorEnv
|
||||
from gym.vector.vector_env import VectorEnv
|
||||
|
||||
__all__ = ['AsyncVectorEnv', 'SyncVectorEnv', 'VectorEnv', 'make']
|
||||
|
||||
def make(id, num_envs=1, asynchronous=True, **kwargs):
|
||||
"""Create a vectorized environment from multiple copies of an environment,
|
||||
from its id
|
||||
|
||||
Parameters
|
||||
----------
|
||||
id : str
|
||||
The environment ID. This must be a valid ID from the registry.
|
||||
|
||||
num_envs : int
|
||||
Number of copies of the environment. If `1`, then it returns an
|
||||
unwrapped (i.e. non-vectorized) environment.
|
||||
|
||||
asynchronous : bool (default: `True`)
|
||||
If `True`, wraps the environments in an `AsyncVectorEnv` (which uses
|
||||
`multiprocessing` to run the environments in parallel). If `False`,
|
||||
wraps the environments in a `SyncVectorEnv`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
env : `gym.vector.VectorEnv` instance
|
||||
The vectorized environment.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import gym
|
||||
>>> env = gym.vector.make('CartPole-v1', 3)
|
||||
>>> env.reset()
|
||||
array([[-0.04456399, 0.04653909, 0.01326909, -0.02099827],
|
||||
[ 0.03073904, 0.00145001, -0.03088818, -0.03131252],
|
||||
[ 0.03468829, 0.01500225, 0.01230312, 0.01825218]],
|
||||
dtype=float32)
|
||||
"""
|
||||
from gym.envs import make as make_
|
||||
def _make_env():
|
||||
return make_(id, **kwargs)
|
||||
if num_envs == 1:
|
||||
return _make_env()
|
||||
env_fns = [_make_env for _ in range(num_envs)]
|
||||
|
||||
return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns)
|
405
gym/vector/async_vector_env.py
Normal file
405
gym/vector/async_vector_env.py
Normal file
@@ -0,0 +1,405 @@
|
||||
import numpy as np
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
import sys
|
||||
from enum import Enum
|
||||
from copy import deepcopy
|
||||
|
||||
from gym import logger
|
||||
from gym.vector.vector_env import VectorEnv
|
||||
from gym.error import (AlreadyPendingCallError, NoAsyncCallError,
|
||||
ClosedEnvironmentError)
|
||||
from gym.vector.utils import (create_shared_memory, create_empty_array,
|
||||
write_to_shared_memory, read_from_shared_memory,
|
||||
concatenate, CloudpickleWrapper, clear_mpi_env_vars)
|
||||
|
||||
__all__ = ['AsyncVectorEnv']
|
||||
|
||||
|
||||
class AsyncState(Enum):
|
||||
DEFAULT = 'default'
|
||||
WAITING_RESET = 'reset'
|
||||
WAITING_STEP = 'step'
|
||||
|
||||
|
||||
class AsyncVectorEnv(VectorEnv):
|
||||
"""Vectorized environment that runs multiple environments in parallel. It
|
||||
uses `multiprocessing` processes, and pipes for communication.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
env_fns : iterable of callable
|
||||
Functions that create the environments.
|
||||
|
||||
observation_space : `gym.spaces.Space` instance, optional
|
||||
Observation space of a single environment. If `None`, then the
|
||||
observation space of the first environment is taken.
|
||||
|
||||
action_space : `gym.spaces.Space` instance, optional
|
||||
Action space of a single environment. If `None`, then the action space
|
||||
of the first environment is taken.
|
||||
|
||||
shared_memory : bool (default: `True`)
|
||||
If `True`, then the observations from the worker processes are
|
||||
communicated back through shared variables. This can improve the
|
||||
efficiency if the observations are large (e.g. images).
|
||||
|
||||
copy : bool (default: `True`)
|
||||
If `True`, then the `reset` and `step` methods return a copy of the
|
||||
observations.
|
||||
|
||||
context : str, optional
|
||||
Context for multiprocessing. If `None`, then the default context is used.
|
||||
Only available in Python 3.
|
||||
"""
|
||||
def __init__(self, env_fns, observation_space=None, action_space=None,
|
||||
shared_memory=True, copy=True, context=None):
|
||||
try:
|
||||
ctx = mp.get_context(context)
|
||||
except AttributeError:
|
||||
logger.warn('Context switching for `multiprocessing` is not '
|
||||
'available in Python 2. Using the default context.')
|
||||
ctx = mp
|
||||
self.env_fns = env_fns
|
||||
self.shared_memory = shared_memory
|
||||
self.copy = copy
|
||||
|
||||
if (observation_space is None) or (action_space is None):
|
||||
dummy_env = env_fns[0]()
|
||||
observation_space = observation_space or dummy_env.observation_space
|
||||
action_space = action_space or dummy_env.action_space
|
||||
dummy_env.close()
|
||||
del dummy_env
|
||||
super(AsyncVectorEnv, self).__init__(num_envs=len(env_fns),
|
||||
observation_space=observation_space, action_space=action_space)
|
||||
|
||||
if self.shared_memory:
|
||||
_obs_buffer = create_shared_memory(self.single_observation_space,
|
||||
n=self.num_envs)
|
||||
self.observations = read_from_shared_memory(_obs_buffer,
|
||||
self.single_observation_space, n=self.num_envs)
|
||||
else:
|
||||
_obs_buffer = None
|
||||
self.observations = create_empty_array(
|
||||
self.single_observation_space, n=self.num_envs, fn=np.zeros)
|
||||
|
||||
self.parent_pipes, self.processes = [], []
|
||||
self.error_queue = ctx.Queue()
|
||||
target = _worker_shared_memory if self.shared_memory else _worker
|
||||
with clear_mpi_env_vars():
|
||||
for idx, env_fn in enumerate(self.env_fns):
|
||||
parent_pipe, child_pipe = ctx.Pipe()
|
||||
process = ctx.Process(target=target,
|
||||
name='Worker<{0}>-{1}'.format(type(self).__name__, idx),
|
||||
args=(idx, CloudpickleWrapper(env_fn), child_pipe,
|
||||
parent_pipe, _obs_buffer, self.error_queue))
|
||||
|
||||
self.parent_pipes.append(parent_pipe)
|
||||
self.processes.append(process)
|
||||
|
||||
process.deamon = True
|
||||
process.start()
|
||||
child_pipe.close()
|
||||
|
||||
self._state = AsyncState.DEFAULT
|
||||
self._check_observation_spaces()
|
||||
|
||||
def seed(self, seeds=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
seeds : list of int, or int, optional
|
||||
Random seed for each individual environment. If `seeds` is a list of
|
||||
length `num_envs`, then the items of the list are chosen as random
|
||||
seeds. If `seeds` is an int, then each environment uses the random
|
||||
seed `seeds + n`, where `n` is the index of the environment (between
|
||||
`0` and `num_envs - 1`).
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if seeds is None:
|
||||
seeds = [None for _ in range(self.num_envs)]
|
||||
if isinstance(seeds, int):
|
||||
seeds = [seeds + i for i in range(self.num_envs)]
|
||||
assert len(seeds) == self.num_envs
|
||||
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError('Calling `seed` while waiting '
|
||||
'for a pending call to `{0}` to complete.'.format(
|
||||
self._state.value), self._state.value)
|
||||
|
||||
for pipe, seed in zip(self.parent_pipes, seeds):
|
||||
pipe.send(('seed', seed))
|
||||
for pipe in self.parent_pipes:
|
||||
pipe.recv()
|
||||
|
||||
def reset_async(self):
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError('Calling `reset_async` while waiting '
|
||||
'for a pending call to `{0}` to complete'.format(
|
||||
self._state.value), self._state.value)
|
||||
|
||||
for pipe in self.parent_pipes:
|
||||
pipe.send(('reset', None))
|
||||
self._state = AsyncState.WAITING_RESET
|
||||
|
||||
def reset_wait(self, timeout=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
timeout : int or float, optional
|
||||
Number of seconds before the call to `reset_wait` times out. If
|
||||
`None`, the call to `reset_wait` never times out.
|
||||
|
||||
Returns
|
||||
-------
|
||||
observations : sample from `observation_space`
|
||||
A batch of observations from the vectorized environment.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.WAITING_RESET:
|
||||
raise NoAsyncCallError('Calling `reset_wait` without any prior '
|
||||
'call to `reset_async`.', AsyncState.WAITING_RESET.value)
|
||||
|
||||
if not self._poll(timeout):
|
||||
self._state = AsyncState.DEFAULT
|
||||
raise mp.TimeoutError('The call to `reset_wait` has timed out after '
|
||||
'{0} second{1}.'.format(timeout, 's' if timeout > 1 else ''))
|
||||
|
||||
self._raise_if_errors()
|
||||
observations_list = [pipe.recv() for pipe in self.parent_pipes]
|
||||
self._state = AsyncState.DEFAULT
|
||||
|
||||
if not self.shared_memory:
|
||||
concatenate(observations_list, self.observations,
|
||||
self.single_observation_space)
|
||||
|
||||
return deepcopy(self.observations) if self.copy else self.observations
|
||||
|
||||
def step_async(self, actions):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
actions : iterable of samples from `action_space`
|
||||
List of actions.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
raise AlreadyPendingCallError('Calling `step_async` while waiting '
|
||||
'for a pending call to `{0}` to complete.'.format(
|
||||
self._state.value), self._state.value)
|
||||
|
||||
for pipe, action in zip(self.parent_pipes, actions):
|
||||
pipe.send(('step', action))
|
||||
self._state = AsyncState.WAITING_STEP
|
||||
|
||||
def step_wait(self, timeout=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
timeout : int or float, optional
|
||||
Number of seconds before the call to `step_wait` times out. If
|
||||
`None`, the call to `step_wait` never times out.
|
||||
|
||||
Returns
|
||||
-------
|
||||
observations : sample from `observation_space`
|
||||
A batch of observations from the vectorized environment.
|
||||
|
||||
rewards : `np.ndarray` instance (dtype `np.float_`)
|
||||
A vector of rewards from the vectorized environment.
|
||||
|
||||
dones : `np.ndarray` instance (dtype `np.bool_`)
|
||||
A vector whose entries indicate whether the episode has ended.
|
||||
|
||||
infos : list of dict
|
||||
A list of auxiliary diagnostic informations.
|
||||
"""
|
||||
self._assert_is_running()
|
||||
if self._state != AsyncState.WAITING_STEP:
|
||||
raise NoAsyncCallError('Calling `step_wait` without any prior call '
|
||||
'to `step_async`.', AsyncState.WAITING_STEP.value)
|
||||
|
||||
if not self._poll(timeout):
|
||||
self._state = AsyncState.DEFAULT
|
||||
raise mp.TimeoutError('The call to `step_wait` has timed out after '
|
||||
'{0} second{1}.'.format(timeout, 's' if timeout > 1 else ''))
|
||||
|
||||
self._raise_if_errors()
|
||||
results = [pipe.recv() for pipe in self.parent_pipes]
|
||||
self._state = AsyncState.DEFAULT
|
||||
observations_list, rewards, dones, infos = zip(*results)
|
||||
|
||||
if not self.shared_memory:
|
||||
concatenate(observations_list, self.observations,
|
||||
self.single_observation_space)
|
||||
|
||||
return (deepcopy(self.observations) if self.copy else self.observations,
|
||||
np.array(rewards), np.array(dones, dtype=np.bool_), infos)
|
||||
|
||||
def close(self, timeout=None, terminate=False):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
timeout : int or float, optional
|
||||
Number of seconds before the call to `close` times out. If `None`,
|
||||
the call to `close` never times out. If the call to `close` times
|
||||
out, then all processes are terminated.
|
||||
|
||||
terminate : bool (default: `False`)
|
||||
If `True`, then the `close` operation is forced and all processes
|
||||
are terminated.
|
||||
"""
|
||||
if self.closed:
|
||||
return
|
||||
|
||||
if self.viewer is not None:
|
||||
self.viewer.close()
|
||||
|
||||
timeout = 0 if terminate else timeout
|
||||
try:
|
||||
if self._state != AsyncState.DEFAULT:
|
||||
logger.warn('Calling `close` while waiting for a pending '
|
||||
'call to `{0}` to complete.'.format(self._state.value))
|
||||
function = getattr(self, '{0}_wait'.format(self._state.value))
|
||||
function(timeout)
|
||||
except mp.TimeoutError:
|
||||
terminate = True
|
||||
|
||||
if terminate:
|
||||
for process in self.processes:
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
else:
|
||||
for pipe in self.parent_pipes:
|
||||
if not pipe.closed:
|
||||
pipe.send(('close', None))
|
||||
for pipe in self.parent_pipes:
|
||||
if not pipe.closed:
|
||||
pipe.recv()
|
||||
|
||||
for pipe in self.parent_pipes:
|
||||
pipe.close()
|
||||
for process in self.processes:
|
||||
process.join()
|
||||
|
||||
self.closed = True
|
||||
|
||||
def _poll(self, timeout=None):
|
||||
self._assert_is_running()
|
||||
if timeout is not None:
|
||||
end_time = time.time() + timeout
|
||||
delta = None
|
||||
for pipe in self.parent_pipes:
|
||||
if timeout is not None:
|
||||
delta = max(end_time - time.time(), 0)
|
||||
if pipe.closed or (not pipe.poll(delta)):
|
||||
break
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_observation_spaces(self):
|
||||
self._assert_is_running()
|
||||
for pipe in self.parent_pipes:
|
||||
pipe.send(('_check_observation_space', self.single_observation_space))
|
||||
if not all([pipe.recv() for pipe in self.parent_pipes]):
|
||||
raise RuntimeError('Some environments have an observation space '
|
||||
'different from `{0}`. In order to batch observations, the '
|
||||
'observation spaces from all environments must be '
|
||||
'equal.'.format(self.single_observation_space))
|
||||
|
||||
def _assert_is_running(self):
|
||||
if self.closed:
|
||||
raise ClosedEnvironmentError('Trying to operate on `{0}`, after a '
|
||||
'call to `close()`.'.format(type(self).__name__))
|
||||
|
||||
def _raise_if_errors(self):
|
||||
if not self.error_queue.empty():
|
||||
while not self.error_queue.empty():
|
||||
index, exctype, value = self.error_queue.get()
|
||||
logger.error('Received the following error from Worker-{0}: '
|
||||
'{1}: {2}'.format(index, exctype.__name__, value))
|
||||
logger.error('Shutting down Worker-{0}.'.format(index))
|
||||
self.parent_pipes[index].close()
|
||||
self.parent_pipes[index] = None
|
||||
logger.error('Raising the last exception back to the main process.')
|
||||
raise exctype(value)
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, 'closed'):
|
||||
if not self.closed:
|
||||
self.close(terminate=True)
|
||||
|
||||
|
||||
def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
assert shared_memory is None
|
||||
env = env_fn()
|
||||
parent_pipe.close()
|
||||
try:
|
||||
while True:
|
||||
command, data = pipe.recv()
|
||||
if command == 'reset':
|
||||
observation = env.reset()
|
||||
pipe.send(observation)
|
||||
elif command == 'step':
|
||||
observation, reward, done, info = env.step(data)
|
||||
if done:
|
||||
observation = env.reset()
|
||||
pipe.send((observation, reward, done, info))
|
||||
elif command == 'seed':
|
||||
env.seed(data)
|
||||
pipe.send(None)
|
||||
elif command == 'close':
|
||||
pipe.send(None)
|
||||
break
|
||||
elif command == '_check_observation_space':
|
||||
pipe.send(data == env.observation_space)
|
||||
else:
|
||||
raise RuntimeError('Received unknown command `{0}`. Must '
|
||||
'be one of {`reset`, `step`, `seed`, `close`, '
|
||||
'`_check_observation_space`}.'.format(command))
|
||||
except Exception:
|
||||
error_queue.put((index,) + sys.exc_info()[:2])
|
||||
pipe.send(None)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
|
||||
def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
assert shared_memory is not None
|
||||
env = env_fn()
|
||||
observation_space = env.observation_space
|
||||
parent_pipe.close()
|
||||
try:
|
||||
while True:
|
||||
command, data = pipe.recv()
|
||||
if command == 'reset':
|
||||
observation = env.reset()
|
||||
write_to_shared_memory(index, observation, shared_memory,
|
||||
observation_space)
|
||||
pipe.send(None)
|
||||
elif command == 'step':
|
||||
observation, reward, done, info = env.step(data)
|
||||
if done:
|
||||
observation = env.reset()
|
||||
write_to_shared_memory(index, observation, shared_memory,
|
||||
observation_space)
|
||||
pipe.send((None, reward, done, info))
|
||||
elif command == 'seed':
|
||||
env.seed(data)
|
||||
pipe.send(None)
|
||||
elif command == 'close':
|
||||
pipe.send(None)
|
||||
break
|
||||
elif command == '_check_observation_space':
|
||||
pipe.send(data == observation_space)
|
||||
else:
|
||||
raise RuntimeError('Received unknown command `{0}`. Must '
|
||||
'be one of {`reset`, `step`, `seed`, `close`, '
|
||||
'`_check_observation_space`}.'.format(command))
|
||||
except Exception:
|
||||
error_queue.put((index,) + sys.exc_info()[:2])
|
||||
pipe.send(None)
|
||||
finally:
|
||||
env.close()
|
137
gym/vector/sync_vector_env.py
Normal file
137
gym/vector/sync_vector_env.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import numpy as np
|
||||
|
||||
from gym import logger
|
||||
from gym.vector.vector_env import VectorEnv
|
||||
from gym.vector.utils import concatenate, create_empty_array
|
||||
|
||||
__all__ = ['SyncVectorEnv']
|
||||
|
||||
|
||||
class SyncVectorEnv(VectorEnv):
|
||||
"""Vectorized environment that serially runs multiple environments.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
env_fns : iterable of callable
|
||||
Functions that create the environments.
|
||||
|
||||
observation_space : `gym.spaces.Space` instance, optional
|
||||
Observation space of a single environment. If `None`, then the
|
||||
observation space of the first environment is taken.
|
||||
|
||||
action_space : `gym.spaces.Space` instance, optional
|
||||
Action space of a single environment. If `None`, then the action space
|
||||
of the first environment is taken.
|
||||
|
||||
copy : bool (default: `True`)
|
||||
If `True`, then the `reset` and `step` methods return a copy of the
|
||||
observations.
|
||||
"""
|
||||
def __init__(self, env_fns, observation_space=None, action_space=None,
|
||||
copy=True):
|
||||
self.env_fns = env_fns
|
||||
self.envs = [env_fn() for env_fn in env_fns]
|
||||
self.copy = copy
|
||||
|
||||
if (observation_space is None) or (action_space is None):
|
||||
observation_space = observation_space or self.envs[0].observation_space
|
||||
action_space = action_space or self.envs[0].action_space
|
||||
super(SyncVectorEnv, self).__init__(num_envs=len(env_fns),
|
||||
observation_space=observation_space, action_space=action_space)
|
||||
|
||||
self._check_observation_spaces()
|
||||
self.observations = create_empty_array(self.single_observation_space,
|
||||
n=self.num_envs, fn=np.zeros)
|
||||
self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
|
||||
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
|
||||
|
||||
def seed(self, seeds=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
seeds : list of int, or int, optional
|
||||
Random seed for each individual environment. If `seeds` is a list of
|
||||
length `num_envs`, then the items of the list are chosen as random
|
||||
seeds. If `seeds` is an int, then each environment uses the random
|
||||
seed `seeds + n`, where `n` is the index of the environment (between
|
||||
`0` and `num_envs - 1`).
|
||||
"""
|
||||
if seeds is None:
|
||||
seeds = [None for _ in range(self.num_envs)]
|
||||
if isinstance(seeds, int):
|
||||
seeds = [seeds + i for i in range(self.num_envs)]
|
||||
assert len(seeds) == self.num_envs
|
||||
|
||||
for env, seed in zip(self.envs, seeds):
|
||||
env.seed(seed)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
observations : sample from `observation_space`
|
||||
A batch of observations from the vectorized environment.
|
||||
"""
|
||||
self._dones[:] = False
|
||||
observations = []
|
||||
for env in self.envs:
|
||||
observation = env.reset()
|
||||
observations.append(observation)
|
||||
concatenate(observations, self.observations, self.single_observation_space)
|
||||
|
||||
return np.copy(self.observations) if self.copy else self.observations
|
||||
|
||||
def step(self, actions):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
actions : iterable of samples from `action_space`
|
||||
List of actions.
|
||||
|
||||
Returns
|
||||
-------
|
||||
observations : sample from `observation_space`
|
||||
A batch of observations from the vectorized environment.
|
||||
|
||||
rewards : `np.ndarray` instance (dtype `np.float_`)
|
||||
A vector of rewards from the vectorized environment.
|
||||
|
||||
dones : `np.ndarray` instance (dtype `np.bool_`)
|
||||
A vector whose entries indicate whether the episode has ended.
|
||||
|
||||
infos : list of dict
|
||||
A list of auxiliary diagnostic informations.
|
||||
"""
|
||||
observations, infos = [], []
|
||||
for i, (env, action) in enumerate(zip(self.envs, actions)):
|
||||
observation, self._rewards[i], self._dones[i], info = env.step(action)
|
||||
if self._dones[i]:
|
||||
observation = env.reset()
|
||||
observations.append(observation)
|
||||
infos.append(info)
|
||||
concatenate(observations, self.observations, self.single_observation_space)
|
||||
|
||||
return (np.copy(self.observations) if self.copy else self.observations,
|
||||
np.copy(self._rewards), np.copy(self._dones), infos)
|
||||
|
||||
def close(self):
|
||||
if self.closed:
|
||||
return
|
||||
if self.viewer is not None:
|
||||
self.viewer.close()
|
||||
|
||||
for env in self.envs:
|
||||
env.close()
|
||||
|
||||
self.closed = True
|
||||
|
||||
def _check_observation_spaces(self):
|
||||
for env in self.envs:
|
||||
if not (env.observation_space == self.single_observation_space):
|
||||
break
|
||||
else:
|
||||
return True
|
||||
raise RuntimeError('Some environments have an observation space '
|
||||
'different from `{0}`. In order to batch observations, the '
|
||||
'observation spaces from all environments must be '
|
||||
'equal.'.format(self.single_observation_space))
|
0
gym/vector/tests/__init__.py
Normal file
0
gym/vector/tests/__init__.py
Normal file
192
gym/vector/tests/test_async_vector_env.py
Normal file
192
gym/vector/tests/test_async_vector_env.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from multiprocessing import TimeoutError
|
||||
from gym.spaces import Box
|
||||
from gym.error import (AlreadyPendingCallError, NoAsyncCallError,
|
||||
ClosedEnvironmentError)
|
||||
from gym.vector.tests.utils import make_env, make_slow_env
|
||||
|
||||
from gym.vector.async_vector_env import AsyncVectorEnv
|
||||
|
||||
@pytest.mark.parametrize('shared_memory', [True, False])
|
||||
def test_create_async_vector_env(shared_memory):
|
||||
env_fns = [make_env('CubeCrash-v0', i) for i in range(8)]
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert env.num_envs == 8
|
||||
|
||||
|
||||
@pytest.mark.parametrize('shared_memory', [True, False])
|
||||
def test_reset_async_vector_env(shared_memory):
|
||||
env_fns = [make_env('CubeCrash-v0', i) for i in range(8)]
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
observations = env.reset()
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert isinstance(observations, np.ndarray)
|
||||
assert observations.dtype == env.observation_space.dtype
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize('shared_memory', [True, False])
|
||||
def test_step_async_vector_env(shared_memory):
|
||||
env_fns = [make_env('CubeCrash-v0', i) for i in range(8)]
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
observations = env.reset()
|
||||
actions = [env.single_action_space.sample() for _ in range(8)]
|
||||
observations, rewards, dones, _ = env.step(actions)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert isinstance(observations, np.ndarray)
|
||||
assert observations.dtype == env.observation_space.dtype
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
|
||||
assert isinstance(rewards, np.ndarray)
|
||||
assert isinstance(rewards[0], (float, np.floating))
|
||||
assert rewards.ndim == 1
|
||||
assert rewards.size == 8
|
||||
|
||||
assert isinstance(dones, np.ndarray)
|
||||
assert dones.dtype == np.bool_
|
||||
assert dones.ndim == 1
|
||||
assert dones.size == 8
|
||||
|
||||
|
||||
@pytest.mark.parametrize('shared_memory', [True, False])
|
||||
def test_copy_async_vector_env(shared_memory):
|
||||
env_fns = [make_env('CubeCrash-v0', i) for i in range(8)]
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory,
|
||||
copy=True)
|
||||
observations = env.reset()
|
||||
observations[0] = 128
|
||||
assert not np.all(env.observations[0] == 128)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('shared_memory', [True, False])
|
||||
def test_no_copy_async_vector_env(shared_memory):
|
||||
env_fns = [make_env('CubeCrash-v0', i) for i in range(8)]
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory,
|
||||
copy=False)
|
||||
observations = env.reset()
|
||||
observations[0] = 128
|
||||
assert np.all(env.observations[0] == 128)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('shared_memory', [True, False])
|
||||
def test_reset_timeout_async_vector_env(shared_memory):
|
||||
env_fns = [make_slow_env(0.3, i) for i in range(4)]
|
||||
with pytest.raises(TimeoutError):
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
env.reset_async()
|
||||
observations = env.reset_wait(timeout=0.1)
|
||||
finally:
|
||||
env.close(terminate=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('shared_memory', [True, False])
|
||||
def test_step_timeout_async_vector_env(shared_memory):
|
||||
env_fns = [make_slow_env(0., i) for i in range(4)]
|
||||
with pytest.raises(TimeoutError):
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
observations = env.reset()
|
||||
env.step_async([0.1, 0.1, 0.3, 0.1])
|
||||
observations, rewards, dones, _ = env.step_wait(timeout=0.1)
|
||||
finally:
|
||||
env.close(terminate=True)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings('ignore::UserWarning')
|
||||
@pytest.mark.parametrize('shared_memory', [True, False])
|
||||
def test_reset_out_of_order_async_vector_env(shared_memory):
|
||||
env_fns = [make_env('CubeCrash-v0', i) for i in range(4)]
|
||||
with pytest.raises(NoAsyncCallError):
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
observations = env.reset_wait()
|
||||
except NoAsyncCallError as exception:
|
||||
assert exception.name == 'reset'
|
||||
raise
|
||||
finally:
|
||||
env.close(terminate=True)
|
||||
|
||||
with pytest.raises(AlreadyPendingCallError):
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
actions = env.action_space.sample()
|
||||
observations = env.reset()
|
||||
env.step_async(actions)
|
||||
env.reset_async()
|
||||
except NoAsyncCallError as exception:
|
||||
assert exception.name == 'step'
|
||||
raise
|
||||
finally:
|
||||
env.close(terminate=True)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings('ignore::UserWarning')
|
||||
@pytest.mark.parametrize('shared_memory', [True, False])
|
||||
def test_step_out_of_order_async_vector_env(shared_memory):
|
||||
env_fns = [make_env('CubeCrash-v0', i) for i in range(4)]
|
||||
with pytest.raises(NoAsyncCallError):
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
actions = env.action_space.sample()
|
||||
observations = env.reset()
|
||||
observations, rewards, dones, infos = env.step_wait()
|
||||
except AlreadyPendingCallError as exception:
|
||||
assert exception.name == 'step'
|
||||
raise
|
||||
finally:
|
||||
env.close(terminate=True)
|
||||
|
||||
with pytest.raises(AlreadyPendingCallError):
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
actions = env.action_space.sample()
|
||||
env.reset_async()
|
||||
env.step_async(actions)
|
||||
except AlreadyPendingCallError as exception:
|
||||
assert exception.name == 'reset'
|
||||
raise
|
||||
finally:
|
||||
env.close(terminate=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('shared_memory', [True, False])
|
||||
def test_already_closed_async_vector_env(shared_memory):
|
||||
env_fns = [make_env('CubeCrash-v0', i) for i in range(4)]
|
||||
with pytest.raises(ClosedEnvironmentError):
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
env.close()
|
||||
observations = env.reset()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('shared_memory', [True, False])
|
||||
def test_check_observations_async_vector_env(shared_memory):
|
||||
# CubeCrash-v0 - observation_space: Box(40, 32, 3)
|
||||
env_fns = [make_env('CubeCrash-v0', i) for i in range(8)]
|
||||
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3)
|
||||
env_fns[1] = make_env('MemorizeDigits-v0', 1)
|
||||
with pytest.raises(RuntimeError):
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
env.close(terminate=True)
|
141
gym/vector/tests/test_numpy_utils.py
Normal file
141
gym/vector/tests/test_numpy_utils.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
from gym.spaces import Tuple, Dict
|
||||
from gym.vector.utils.spaces import _BaseGymSpaces
|
||||
from gym.vector.tests.utils import spaces
|
||||
|
||||
from gym.vector.utils.numpy_utils import concatenate, create_empty_array
|
||||
|
||||
@pytest.mark.parametrize('space', spaces,
|
||||
ids=[space.__class__.__name__ for space in spaces])
|
||||
def test_concatenate(space):
|
||||
def assert_type(lhs, rhs, n):
|
||||
# Special case: if rhs is a list of scalars, lhs must be an np.ndarray
|
||||
if np.isscalar(rhs[0]):
|
||||
assert isinstance(lhs, np.ndarray)
|
||||
assert all([np.isscalar(rhs[i]) for i in range(n)])
|
||||
else:
|
||||
assert all([isinstance(rhs[i], type(lhs)) for i in range(n)])
|
||||
|
||||
def assert_nested_equal(lhs, rhs, n):
|
||||
assert isinstance(rhs, list)
|
||||
assert (n > 0) and (len(rhs) == n)
|
||||
assert_type(lhs, rhs, n)
|
||||
if isinstance(lhs, np.ndarray):
|
||||
assert lhs.shape[0] == n
|
||||
for i in range(n):
|
||||
assert np.all(lhs[i] == rhs[i])
|
||||
|
||||
elif isinstance(lhs, tuple):
|
||||
for i in range(len(lhs)):
|
||||
rhs_T_i = [rhs[j][i] for j in range(n)]
|
||||
assert_nested_equal(lhs[i], rhs_T_i, n)
|
||||
|
||||
elif isinstance(lhs, OrderedDict):
|
||||
for key in lhs.keys():
|
||||
rhs_T_key = [rhs[j][key] for j in range(n)]
|
||||
assert_nested_equal(lhs[key], rhs_T_key, n)
|
||||
|
||||
else:
|
||||
raise TypeError('Got unknown type `{0}`.'.format(type(lhs)))
|
||||
|
||||
samples = [space.sample() for _ in range(8)]
|
||||
array = create_empty_array(space, n=8)
|
||||
concatenated = concatenate(samples, array, space)
|
||||
|
||||
assert np.all(concatenated == array)
|
||||
assert_nested_equal(array, samples, n=8)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('n', [1, 8])
|
||||
@pytest.mark.parametrize('space', spaces,
|
||||
ids=[space.__class__.__name__ for space in spaces])
|
||||
def test_create_empty_array(space, n):
|
||||
|
||||
def assert_nested_type(arr, space, n):
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
assert isinstance(arr, np.ndarray)
|
||||
assert arr.dtype == space.dtype
|
||||
assert arr.shape == (n,) + space.shape
|
||||
|
||||
elif isinstance(space, Tuple):
|
||||
assert isinstance(arr, tuple)
|
||||
assert len(arr) == len(space.spaces)
|
||||
for i in range(len(arr)):
|
||||
assert_nested_type(arr[i], space.spaces[i], n)
|
||||
|
||||
elif isinstance(space, Dict):
|
||||
assert isinstance(arr, OrderedDict)
|
||||
assert set(arr.keys()) ^ set(space.spaces.keys()) == set()
|
||||
for key in arr.keys():
|
||||
assert_nested_type(arr[key], space.spaces[key], n)
|
||||
|
||||
else:
|
||||
raise TypeError('Got unknown type `{0}`.'.format(type(arr)))
|
||||
|
||||
array = create_empty_array(space, n=n, fn=np.empty)
|
||||
assert_nested_type(array, space, n=n)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('n', [1, 8])
|
||||
@pytest.mark.parametrize('space', spaces,
|
||||
ids=[space.__class__.__name__ for space in spaces])
|
||||
def test_create_empty_array_zeros(space, n):
|
||||
|
||||
def assert_nested_type(arr, space, n):
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
assert isinstance(arr, np.ndarray)
|
||||
assert arr.dtype == space.dtype
|
||||
assert arr.shape == (n,) + space.shape
|
||||
assert np.all(arr == 0)
|
||||
|
||||
elif isinstance(space, Tuple):
|
||||
assert isinstance(arr, tuple)
|
||||
assert len(arr) == len(space.spaces)
|
||||
for i in range(len(arr)):
|
||||
assert_nested_type(arr[i], space.spaces[i], n)
|
||||
|
||||
elif isinstance(space, Dict):
|
||||
assert isinstance(arr, OrderedDict)
|
||||
assert set(arr.keys()) ^ set(space.spaces.keys()) == set()
|
||||
for key in arr.keys():
|
||||
assert_nested_type(arr[key], space.spaces[key], n)
|
||||
|
||||
else:
|
||||
raise TypeError('Got unknown type `{0}`.'.format(type(arr)))
|
||||
|
||||
array = create_empty_array(space, n=n, fn=np.zeros)
|
||||
assert_nested_type(array, space, n=n)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('space', spaces,
|
||||
ids=[space.__class__.__name__ for space in spaces])
|
||||
def test_create_empty_array_none_shape_ones(space):
|
||||
|
||||
def assert_nested_type(arr, space):
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
assert isinstance(arr, np.ndarray)
|
||||
assert arr.dtype == space.dtype
|
||||
assert arr.shape == space.shape
|
||||
assert np.all(arr == 1)
|
||||
|
||||
elif isinstance(space, Tuple):
|
||||
assert isinstance(arr, tuple)
|
||||
assert len(arr) == len(space.spaces)
|
||||
for i in range(len(arr)):
|
||||
assert_nested_type(arr[i], space.spaces[i])
|
||||
|
||||
elif isinstance(space, Dict):
|
||||
assert isinstance(arr, OrderedDict)
|
||||
assert set(arr.keys()) ^ set(space.spaces.keys()) == set()
|
||||
for key in arr.keys():
|
||||
assert_nested_type(arr[key], space.spaces[key])
|
||||
|
||||
else:
|
||||
raise TypeError('Got unknown type `{0}`.'.format(type(arr)))
|
||||
|
||||
array = create_empty_array(space, n=None, fn=np.ones)
|
||||
assert_nested_type(array, space)
|
137
gym/vector/tests/test_shared_memory.py
Normal file
137
gym/vector/tests/test_shared_memory.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from multiprocessing.sharedctypes import SynchronizedArray
|
||||
from multiprocessing import Array, Process
|
||||
from collections import OrderedDict
|
||||
|
||||
from gym.spaces import Tuple, Dict
|
||||
from gym.vector.utils.spaces import _BaseGymSpaces
|
||||
from gym.vector.tests.utils import spaces
|
||||
|
||||
from gym.vector.utils.shared_memory import (create_shared_memory,
|
||||
read_from_shared_memory, write_to_shared_memory)
|
||||
|
||||
expected_types = [
|
||||
Array('d', 1), Array('f', 1), Array('f', 3), Array('f', 4), Array('B', 1), Array('B', 32 * 32 * 3),
|
||||
Array('i', 1), (Array('i', 1), Array('i', 1)), (Array('i', 1), Array('f', 2)),
|
||||
Array('B', 3), Array('B', 19),
|
||||
OrderedDict([
|
||||
('position', Array('i', 1)),
|
||||
('velocity', Array('f', 1))
|
||||
]),
|
||||
OrderedDict([
|
||||
('position', OrderedDict([('x', Array('i', 1)), ('y', Array('i', 1))])),
|
||||
('velocity', (Array('i', 1), Array('B', 1)))
|
||||
])
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize('n', [1, 8])
|
||||
@pytest.mark.parametrize('space,expected_type', list(zip(spaces, expected_types)),
|
||||
ids=[space.__class__.__name__ for space in spaces])
|
||||
def test_create_shared_memory(space, expected_type, n):
|
||||
def assert_nested_type(lhs, rhs, n):
|
||||
assert type(lhs) == type(rhs)
|
||||
if isinstance(lhs, (list, tuple)):
|
||||
assert len(lhs) == len(rhs)
|
||||
for lhs_, rhs_ in zip(lhs, rhs):
|
||||
assert_nested_type(lhs_, rhs_, n)
|
||||
|
||||
elif isinstance(lhs, (dict, OrderedDict)):
|
||||
assert set(lhs.keys()) ^ set(rhs.keys()) == set()
|
||||
for key in lhs.keys():
|
||||
assert_nested_type(lhs[key], rhs[key], n)
|
||||
|
||||
elif isinstance(lhs, SynchronizedArray):
|
||||
# Assert the length of the array
|
||||
assert len(lhs[:]) == n * len(rhs[:])
|
||||
# Assert the data type
|
||||
assert type(lhs[0]) == type(rhs[0])
|
||||
|
||||
else:
|
||||
raise TypeError('Got unknown type `{0}`.'.format(type(lhs)))
|
||||
|
||||
shared_memory = create_shared_memory(space, n=n)
|
||||
assert_nested_type(shared_memory, expected_type, n=n)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('space', spaces,
|
||||
ids=[space.__class__.__name__ for space in spaces])
|
||||
def test_write_to_shared_memory(space):
|
||||
|
||||
def assert_nested_equal(lhs, rhs):
|
||||
assert isinstance(rhs, list)
|
||||
if isinstance(lhs, (list, tuple)):
|
||||
for i in range(len(lhs)):
|
||||
assert_nested_equal(lhs[i], [rhs_[i] for rhs_ in rhs])
|
||||
|
||||
elif isinstance(lhs, (dict, OrderedDict)):
|
||||
for key in lhs.keys():
|
||||
assert_nested_equal(lhs[key], [rhs_[key] for rhs_ in rhs])
|
||||
|
||||
elif isinstance(lhs, SynchronizedArray):
|
||||
assert np.all(np.array(lhs[:]) == np.stack(rhs, axis=0).flatten())
|
||||
|
||||
else:
|
||||
raise TypeError('Got unknown type `{0}`.'.format(type(lhs)))
|
||||
|
||||
def write(i, shared_memory, sample):
|
||||
write_to_shared_memory(i, sample, shared_memory, space)
|
||||
|
||||
shared_memory_n8 = create_shared_memory(space, n=8)
|
||||
samples = [space.sample() for _ in range(8)]
|
||||
|
||||
processes = [Process(target=write, args=(i, shared_memory_n8,
|
||||
samples[i])) for i in range(8)]
|
||||
|
||||
for process in processes:
|
||||
process.start()
|
||||
for process in processes:
|
||||
process.join()
|
||||
|
||||
assert_nested_equal(shared_memory_n8, samples)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('space', spaces,
|
||||
ids=[space.__class__.__name__ for space in spaces])
|
||||
def test_read_from_shared_memory(space):
|
||||
|
||||
def assert_nested_equal(lhs, rhs, space, n):
|
||||
assert isinstance(rhs, list)
|
||||
if isinstance(space, Tuple):
|
||||
assert isinstance(lhs, tuple)
|
||||
for i in range(len(lhs)):
|
||||
assert_nested_equal(lhs[i], [rhs_[i] for rhs_ in rhs],
|
||||
space.spaces[i], n)
|
||||
|
||||
elif isinstance(space, Dict):
|
||||
assert isinstance(lhs, OrderedDict)
|
||||
for key in lhs.keys():
|
||||
assert_nested_equal(lhs[key], [rhs_[key] for rhs_ in rhs],
|
||||
space.spaces[key], n)
|
||||
|
||||
elif isinstance(space, _BaseGymSpaces):
|
||||
assert isinstance(lhs, np.ndarray)
|
||||
assert lhs.shape == ((n,) + space.shape)
|
||||
assert lhs.dtype == space.dtype
|
||||
assert np.all(lhs == np.stack(rhs, axis=0))
|
||||
|
||||
else:
|
||||
raise TypeError('Got unknown type `{0}`'.format(type(space)))
|
||||
|
||||
def write(i, shared_memory, sample):
|
||||
write_to_shared_memory(i, sample, shared_memory, space)
|
||||
|
||||
shared_memory_n8 = create_shared_memory(space, n=8)
|
||||
memory_view_n8 = read_from_shared_memory(shared_memory_n8, space, n=8)
|
||||
samples = [space.sample() for _ in range(8)]
|
||||
|
||||
processes = [Process(target=write, args=(i, shared_memory_n8,
|
||||
samples[i])) for i in range(8)]
|
||||
|
||||
for process in processes:
|
||||
process.start()
|
||||
for process in processes:
|
||||
process.join()
|
||||
|
||||
assert_nested_equal(memory_view_n8, samples, space, n=8)
|
39
gym/vector/tests/test_spaces.py
Normal file
39
gym/vector/tests/test_spaces.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from gym.spaces import Box, MultiDiscrete, Tuple, Dict
|
||||
from gym.vector.tests.utils import spaces
|
||||
|
||||
from gym.vector.utils.spaces import _BaseGymSpaces, batch_space
|
||||
|
||||
expected_batch_spaces_4 = [
|
||||
Box(low=-1., high=1., shape=(4,), dtype=np.float64),
|
||||
Box(low=0., high=10., shape=(4, 1), dtype=np.float32),
|
||||
Box(low=np.array([[-1., 0., 0.], [-1., 0., 0.], [-1., 0., 0.], [-1., 0., 0.]]),
|
||||
high=np.array([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]), dtype=np.float32),
|
||||
Box(low=np.array([[[-1., 0.], [0., -1.]], [[-1., 0.], [0., -1.]], [[-1., 0.], [0., -1]],
|
||||
[[-1., 0.], [0., -1.]]]), high=np.ones((4, 2, 2)), dtype=np.float32),
|
||||
Box(low=0, high=255, shape=(4,), dtype=np.uint8),
|
||||
Box(low=0, high=255, shape=(4, 32, 32, 3), dtype=np.uint8),
|
||||
MultiDiscrete([2, 2, 2, 2]),
|
||||
Tuple((MultiDiscrete([3, 3, 3, 3]), MultiDiscrete([5, 5, 5, 5]))),
|
||||
Tuple((MultiDiscrete([7, 7, 7, 7]), Box(low=np.array([[0., -1.], [0., -1.], [0., -1.], [0., -1]]),
|
||||
high=np.array([[1., 1.], [1., 1.], [1., 1.], [1., 1.]]), dtype=np.float32))),
|
||||
Box(low=np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]),
|
||||
high=np.array([[10, 12, 16], [10, 12, 16], [10, 12, 16], [10, 12, 16]]), dtype=np.int64),
|
||||
Box(low=0, high=1, shape=(4, 19), dtype=np.int8),
|
||||
Dict({
|
||||
'position': MultiDiscrete([23, 23, 23, 23]),
|
||||
'velocity': Box(low=0., high=1., shape=(4, 1), dtype=np.float32)
|
||||
}),
|
||||
Dict({
|
||||
'position': Dict({'x': MultiDiscrete([29, 29, 29, 29]), 'y': MultiDiscrete([31, 31, 31, 31])}),
|
||||
'velocity': Tuple((MultiDiscrete([37, 37, 37, 37]), Box(low=0, high=255, shape=(4,), dtype=np.uint8)))
|
||||
})
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize('space,expected_batch_space_4', list(zip(spaces,
|
||||
expected_batch_spaces_4)), ids=[space.__class__.__name__ for space in spaces])
|
||||
def test_batch_space(space, expected_batch_space_4):
|
||||
batch_space_4 = batch_space(space, n=4)
|
||||
assert batch_space_4 == expected_batch_space_4
|
68
gym/vector/tests/test_sync_vector_env.py
Normal file
68
gym/vector/tests/test_sync_vector_env.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from gym.spaces import Box
|
||||
from gym.vector.tests.utils import make_env
|
||||
|
||||
from gym.vector.sync_vector_env import SyncVectorEnv
|
||||
|
||||
def test_create_sync_vector_env():
|
||||
env_fns = [make_env('CubeCrash-v0', i) for i in range(8)]
|
||||
try:
|
||||
env = SyncVectorEnv(env_fns)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert env.num_envs == 8
|
||||
|
||||
|
||||
def test_reset_sync_vector_env():
|
||||
env_fns = [make_env('CubeCrash-v0', i) for i in range(8)]
|
||||
try:
|
||||
env = SyncVectorEnv(env_fns)
|
||||
observations = env.reset()
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert isinstance(observations, np.ndarray)
|
||||
assert observations.dtype == env.observation_space.dtype
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
|
||||
|
||||
def test_step_sync_vector_env():
|
||||
env_fns = [make_env('CubeCrash-v0', i) for i in range(8)]
|
||||
try:
|
||||
env = SyncVectorEnv(env_fns)
|
||||
observations = env.reset()
|
||||
actions = [env.single_action_space.sample() for _ in range(8)]
|
||||
observations, rewards, dones, _ = env.step(actions)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
assert isinstance(observations, np.ndarray)
|
||||
assert observations.dtype == env.observation_space.dtype
|
||||
assert observations.shape == (8,) + env.single_observation_space.shape
|
||||
assert observations.shape == env.observation_space.shape
|
||||
|
||||
assert isinstance(rewards, np.ndarray)
|
||||
assert isinstance(rewards[0], (float, np.floating))
|
||||
assert rewards.ndim == 1
|
||||
assert rewards.size == 8
|
||||
|
||||
assert isinstance(dones, np.ndarray)
|
||||
assert dones.dtype == np.bool_
|
||||
assert dones.ndim == 1
|
||||
assert dones.size == 8
|
||||
|
||||
|
||||
def test_check_observations_sync_vector_env():
|
||||
# CubeCrash-v0 - observation_space: Box(40, 32, 3)
|
||||
env_fns = [make_env('CubeCrash-v0', i) for i in range(8)]
|
||||
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3)
|
||||
env_fns[1] = make_env('MemorizeDigits-v0', 1)
|
||||
with pytest.raises(RuntimeError):
|
||||
env = SyncVectorEnv(env_fns)
|
||||
env.close()
|
43
gym/vector/tests/test_vector_env.py
Normal file
43
gym/vector/tests/test_vector_env.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from gym.vector.tests.utils import make_env
|
||||
|
||||
from gym.vector.async_vector_env import AsyncVectorEnv
|
||||
from gym.vector.sync_vector_env import SyncVectorEnv
|
||||
|
||||
@pytest.mark.parametrize('shared_memory', [True, False])
|
||||
def test_vector_env_equal(shared_memory):
|
||||
env_fns = [make_env('CubeCrash-v0', i) for i in range(4)]
|
||||
num_steps = 100
|
||||
try:
|
||||
async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
sync_env = SyncVectorEnv(env_fns)
|
||||
|
||||
async_env.seed(0)
|
||||
sync_env.seed(0)
|
||||
|
||||
assert async_env.num_envs == sync_env.num_envs
|
||||
assert async_env.observation_space == sync_env.observation_space
|
||||
assert async_env.single_observation_space == sync_env.single_observation_space
|
||||
assert async_env.action_space == sync_env.action_space
|
||||
assert async_env.single_action_space == sync_env.single_action_space
|
||||
|
||||
async_observations = async_env.reset()
|
||||
sync_observations = sync_env.reset()
|
||||
assert np.all(async_observations == sync_observations)
|
||||
|
||||
for _ in range(num_steps):
|
||||
actions = async_env.action_space.sample()
|
||||
assert actions in sync_env.action_space
|
||||
|
||||
async_observations, async_rewards, async_dones, _ = async_env.step(actions)
|
||||
sync_observations, sync_rewards, sync_dones, _ = sync_env.step(actions)
|
||||
|
||||
assert np.all(async_observations == sync_observations)
|
||||
assert np.all(async_rewards == sync_rewards)
|
||||
assert np.all(async_dones == sync_dones)
|
||||
|
||||
finally:
|
||||
async_env.close()
|
||||
sync_env.close()
|
62
gym/vector/tests/utils.py
Normal file
62
gym/vector/tests/utils.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import numpy as np
|
||||
import gym
|
||||
import time
|
||||
|
||||
from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
|
||||
|
||||
spaces = [
|
||||
Box(low=np.array(-1.), high=np.array(1.), dtype=np.float64),
|
||||
Box(low=np.array([0.]), high=np.array([10.]), dtype=np.float32),
|
||||
Box(low=np.array([-1., 0., 0.]), high=np.array([1., 1., 1.]), dtype=np.float32),
|
||||
Box(low=np.array([[-1., 0.], [0., -1.]]), high=np.ones((2, 2)), dtype=np.float32),
|
||||
Box(low=0, high=255, shape=(), dtype=np.uint8),
|
||||
Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
|
||||
Discrete(2),
|
||||
Tuple((Discrete(3), Discrete(5))),
|
||||
Tuple((Discrete(7), Box(low=np.array([0., -1.]), high=np.array([1., 1.]), dtype=np.float32))),
|
||||
MultiDiscrete([11, 13, 17]),
|
||||
MultiBinary(19),
|
||||
Dict({
|
||||
'position': Discrete(23),
|
||||
'velocity': Box(low=np.array([0.]), high=np.array([1.]), dtype=np.float32)
|
||||
}),
|
||||
Dict({
|
||||
'position': Dict({'x': Discrete(29), 'y': Discrete(31)}),
|
||||
'velocity': Tuple((Discrete(37), Box(low=0, high=255, shape=(), dtype=np.uint8)))
|
||||
})
|
||||
]
|
||||
|
||||
HEIGHT, WIDTH = 64, 64
|
||||
|
||||
class UnittestSlowEnv(gym.Env):
|
||||
def __init__(self, slow_reset=0.3):
|
||||
super(UnittestSlowEnv, self).__init__()
|
||||
self.slow_reset = slow_reset
|
||||
self.observation_space = Box(low=0, high=255,
|
||||
shape=(HEIGHT, WIDTH, 3), dtype=np.uint8)
|
||||
self.action_space = Box(low=0., high=1., shape=(), dtype=np.float32)
|
||||
|
||||
def reset(self):
|
||||
if self.slow_reset > 0:
|
||||
time.sleep(self.slow_reset)
|
||||
return self.observation_space.sample()
|
||||
|
||||
def step(self, action):
|
||||
time.sleep(action)
|
||||
observation = self.observation_space.sample()
|
||||
reward, done = 0., False
|
||||
return observation, reward, done, {}
|
||||
|
||||
def make_env(env_name, seed):
|
||||
def _make():
|
||||
env = gym.make(env_name)
|
||||
env.seed(seed)
|
||||
return env
|
||||
return _make
|
||||
|
||||
def make_slow_env(slow_reset, seed):
|
||||
def _make():
|
||||
env = UnittestSlowEnv(slow_reset=slow_reset)
|
||||
env.seed(seed)
|
||||
return env
|
||||
return _make
|
16
gym/vector/utils/__init__.py
Normal file
16
gym/vector/utils/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from gym.vector.utils.misc import CloudpickleWrapper, clear_mpi_env_vars
|
||||
from gym.vector.utils.numpy_utils import concatenate, create_empty_array
|
||||
from gym.vector.utils.shared_memory import create_shared_memory, read_from_shared_memory, write_to_shared_memory
|
||||
from gym.vector.utils.spaces import _BaseGymSpaces, batch_space
|
||||
|
||||
__all__ = [
|
||||
'CloudpickleWrapper',
|
||||
'clear_mpi_env_vars',
|
||||
'concatenate',
|
||||
'create_empty_array',
|
||||
'create_shared_memory',
|
||||
'read_from_shared_memory',
|
||||
'write_to_shared_memory',
|
||||
'_BaseGymSpaces',
|
||||
'batch_space'
|
||||
]
|
40
gym/vector/utils/misc.py
Normal file
40
gym/vector/utils/misc.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import contextlib
|
||||
import os
|
||||
|
||||
__all__ = ['CloudpickleWrapper', 'clear_mpi_env_vars']
|
||||
|
||||
class CloudpickleWrapper(object):
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
||||
def __getstate__(self):
|
||||
import cloudpickle
|
||||
return cloudpickle.dumps(self.fn)
|
||||
|
||||
def __setstate__(self, ob):
|
||||
import pickle
|
||||
self.fn = pickle.loads(ob)
|
||||
|
||||
def __call__(self):
|
||||
return self.fn()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def clear_mpi_env_vars():
|
||||
"""
|
||||
`from mpi4py import MPI` will call `MPI_Init` by default. If the child
|
||||
process has MPI environment variables, MPI will think that the child process
|
||||
is an MPI process just like the parent and do bad things such as hang.
|
||||
|
||||
This context manager is a hacky way to clear those environment variables
|
||||
temporarily such as when we are starting multiprocessing Processes.
|
||||
"""
|
||||
removed_environment = {}
|
||||
for k, v in list(os.environ.items()):
|
||||
for prefix in ['OMPI_', 'PMI_']:
|
||||
if k.startswith(prefix):
|
||||
removed_environment[k] = v
|
||||
del os.environ[k]
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
os.environ.update(removed_environment)
|
112
gym/vector/utils/numpy_utils.py
Normal file
112
gym/vector/utils/numpy_utils.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import numpy as np
|
||||
|
||||
from gym.spaces import Tuple, Dict
|
||||
from gym.vector.utils.spaces import _BaseGymSpaces
|
||||
from collections import OrderedDict
|
||||
|
||||
__all__ = ['concatenate', 'create_empty_array']
|
||||
|
||||
def concatenate(items, out, space):
|
||||
"""Concatenate multiple samples from space into a single object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
items : iterable of samples of `space`
|
||||
Samples to be concatenated.
|
||||
|
||||
out : tuple, dict, or `np.ndarray`
|
||||
The output object. This object is a (possibly nested) numpy array.
|
||||
|
||||
space : `gym.spaces.Space` instance
|
||||
Observation space of a single environment in the vectorized environment.
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : tuple, dict, or `np.ndarray`
|
||||
The output object. This object is a (possibly nested) numpy array.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> from gym.spaces import Box
|
||||
>>> space = Box(low=0, high=1, shape=(3,), dtype=np.float32)
|
||||
>>> out = np.zeros((2, 3), dtype=np.float32)
|
||||
>>> items = [space.sample() for _ in range(2)]
|
||||
>>> concatenate(items, out, space)
|
||||
array([[0.6348213 , 0.28607962, 0.60760117],
|
||||
[0.87383074, 0.192658 , 0.2148103 ]], dtype=float32)
|
||||
"""
|
||||
assert isinstance(items, (list, tuple))
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
return concatenate_base(items, out, space)
|
||||
elif isinstance(space, Tuple):
|
||||
return concatenate_tuple(items, out, space)
|
||||
elif isinstance(space, Dict):
|
||||
return concatenate_dict(items, out, space)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def concatenate_base(items, out, space):
|
||||
return np.stack(items, axis=0, out=out)
|
||||
|
||||
def concatenate_tuple(items, out, space):
|
||||
return tuple(concatenate([item[i] for item in items],
|
||||
out[i], subspace) for (i, subspace) in enumerate(space.spaces))
|
||||
|
||||
def concatenate_dict(items, out, space):
|
||||
return OrderedDict([(key, concatenate([item[key] for item in items],
|
||||
out[key], subspace)) for (key, subspace) in space.spaces.items()])
|
||||
|
||||
|
||||
def create_empty_array(space, n=1, fn=np.zeros):
|
||||
"""Create an empty (possibly nested) numpy array.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
space : `gym.spaces.Space` instance
|
||||
Observation space of a single environment in the vectorized environment.
|
||||
|
||||
n : int
|
||||
Number of environments in the vectorized environment. If `None`, creates
|
||||
an empty sample from `space`.
|
||||
|
||||
fn : callable
|
||||
Function to apply when creating the empty numpy array. Examples of such
|
||||
functions are `np.empty` or `np.zeros`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : tuple, dict, or `np.ndarray`
|
||||
The output object. This object is a (possibly nested) numpy array.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> from gym.spaces import Box, Dict
|
||||
>>> space = Dict({
|
||||
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
||||
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
|
||||
>>> create_empty_array(space, n=2, fn=np.zeros)
|
||||
OrderedDict([('position', array([[0., 0., 0.],
|
||||
[0., 0., 0.]], dtype=float32)),
|
||||
('velocity', array([[0., 0.],
|
||||
[0., 0.]], dtype=float32))])
|
||||
"""
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
return create_empty_array_base(space, n=n, fn=fn)
|
||||
elif isinstance(space, Tuple):
|
||||
return create_empty_array_tuple(space, n=n, fn=fn)
|
||||
elif isinstance(space, Dict):
|
||||
return create_empty_array_dict(space, n=n, fn=fn)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def create_empty_array_base(space, n=1, fn=np.zeros):
|
||||
shape = space.shape if (n is None) else (n,) + space.shape
|
||||
return fn(shape, dtype=space.dtype)
|
||||
|
||||
def create_empty_array_tuple(space, n=1, fn=np.zeros):
|
||||
return tuple(create_empty_array(subspace, n=n, fn=fn)
|
||||
for subspace in space.spaces)
|
||||
|
||||
def create_empty_array_dict(space, n=1, fn=np.zeros):
|
||||
return OrderedDict([(key, create_empty_array(subspace, n=n, fn=fn))
|
||||
for (key, subspace) in space.spaces.items()])
|
150
gym/vector/utils/shared_memory.py
Normal file
150
gym/vector/utils/shared_memory.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import numpy as np
|
||||
from multiprocessing import Array
|
||||
from ctypes import c_bool
|
||||
from collections import OrderedDict
|
||||
|
||||
from gym import logger
|
||||
from gym.spaces import Tuple, Dict
|
||||
from gym.vector.utils.spaces import _BaseGymSpaces
|
||||
|
||||
__all__ = [
|
||||
'create_shared_memory',
|
||||
'read_from_shared_memory',
|
||||
'write_to_shared_memory'
|
||||
]
|
||||
|
||||
def create_shared_memory(space, n=1):
|
||||
"""Create a shared memory object, to be shared across processes. This
|
||||
eventually contains the observations from the vectorized environment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
space : `gym.spaces.Space` instance
|
||||
Observation space of a single environment in the vectorized environment.
|
||||
|
||||
n : int
|
||||
Number of environments in the vectorized environment (i.e. the number
|
||||
of processes).
|
||||
|
||||
Returns
|
||||
-------
|
||||
shared_memory : dict, tuple, or `multiprocessing.Array` instance
|
||||
Shared object across processes.
|
||||
"""
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
return create_base_shared_memory(space, n=n)
|
||||
elif isinstance(space, Tuple):
|
||||
return create_tuple_shared_memory(space, n=n)
|
||||
elif isinstance(space, Dict):
|
||||
return create_dict_shared_memory(space, n=n)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def create_base_shared_memory(space, n=1):
|
||||
dtype = space.dtype.char
|
||||
if dtype in '?':
|
||||
dtype = c_bool
|
||||
return Array(dtype, n * int(np.prod(space.shape)))
|
||||
|
||||
def create_tuple_shared_memory(space, n=1):
|
||||
return tuple(create_shared_memory(subspace, n=n)
|
||||
for subspace in space.spaces)
|
||||
|
||||
def create_dict_shared_memory(space, n=1):
|
||||
return OrderedDict([(key, create_shared_memory(subspace, n=n))
|
||||
for (key, subspace) in space.spaces.items()])
|
||||
|
||||
|
||||
def read_from_shared_memory(shared_memory, space, n=1):
|
||||
"""Read the batch of observations from shared memory as a numpy array.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shared_memory : dict, tuple, or `multiprocessing.Array` instance
|
||||
Shared object across processes. This contains the observations from the
|
||||
vectorized environment. This object is created with `create_shared_memory`.
|
||||
|
||||
space : `gym.spaces.Space` instance
|
||||
Observation space of a single environment in the vectorized environment.
|
||||
|
||||
n : int
|
||||
Number of environments in the vectorized environment (i.e. the number
|
||||
of processes).
|
||||
|
||||
Returns
|
||||
-------
|
||||
observations : dict, tuple or `np.ndarray` instance
|
||||
Batch of observations as a (possibly nested) numpy array.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The numpy array objects returned by `read_from_shared_memory` shares the
|
||||
memory of `shared_memory`. Any changes to `shared_memory` are forwarded
|
||||
to `observations`, and vice-versa. To avoid any side-effect, use `np.copy`.
|
||||
"""
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
return read_base_from_shared_memory(shared_memory, space, n=n)
|
||||
elif isinstance(space, Tuple):
|
||||
return read_tuple_from_shared_memory(shared_memory, space, n=n)
|
||||
elif isinstance(space, Dict):
|
||||
return read_dict_from_shared_memory(shared_memory, space, n=n)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def read_base_from_shared_memory(shared_memory, space, n=1):
|
||||
return np.frombuffer(shared_memory.get_obj(),
|
||||
dtype=space.dtype).reshape((n,) + space.shape)
|
||||
|
||||
def read_tuple_from_shared_memory(shared_memory, space, n=1):
|
||||
return tuple(read_from_shared_memory(memory, subspace, n=n)
|
||||
for (memory, subspace) in zip(shared_memory, space.spaces))
|
||||
|
||||
def read_dict_from_shared_memory(shared_memory, space, n=1):
|
||||
return OrderedDict([(key, read_from_shared_memory(memory, subspace, n=n))
|
||||
for ((key, memory), subspace) in zip(shared_memory.items(),
|
||||
space.spaces.values())])
|
||||
|
||||
|
||||
def write_to_shared_memory(index, value, shared_memory, space):
|
||||
"""Write the observation of a single environment into shared memory.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
index : int
|
||||
Index of the environment (must be in `[0, num_envs)`).
|
||||
|
||||
value : sample from `space`
|
||||
Observation of the single environment to write to shared memory.
|
||||
|
||||
shared_memory : dict, tuple, or `multiprocessing.Array` instance
|
||||
Shared object across processes. This contains the observations from the
|
||||
vectorized environment. This object is created with `create_shared_memory`.
|
||||
|
||||
space : `gym.spaces.Space` instance
|
||||
Observation space of a single environment in the vectorized environment.
|
||||
|
||||
Returns
|
||||
-------
|
||||
`None`
|
||||
"""
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
write_base_to_shared_memory(index, value, shared_memory, space)
|
||||
elif isinstance(space, Tuple):
|
||||
write_tuple_to_shared_memory(index, value, shared_memory, space)
|
||||
elif isinstance(space, Dict):
|
||||
write_dict_to_shared_memory(index, value, shared_memory, space)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def write_base_to_shared_memory(index, value, shared_memory, space):
|
||||
size = int(np.prod(space.shape))
|
||||
shared_memory[index * size:(index + 1) * size] = np.asarray(value,
|
||||
dtype=space.dtype).flatten()
|
||||
|
||||
def write_tuple_to_shared_memory(index, values, shared_memory, space):
|
||||
for value, memory, subspace in zip(values, shared_memory, space.spaces):
|
||||
write_to_shared_memory(index, value, memory, subspace)
|
||||
|
||||
def write_dict_to_shared_memory(index, values, shared_memory, space):
|
||||
for key, value in values.items():
|
||||
write_to_shared_memory(index, value, shared_memory[key], space.spaces[key])
|
70
gym/vector/utils/spaces.py
Normal file
70
gym/vector/utils/spaces.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
|
||||
from gym.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
|
||||
|
||||
_BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
|
||||
__all__ = ['_BaseGymSpaces', 'batch_space']
|
||||
|
||||
def batch_space(space, n=1):
|
||||
"""Create a (batched) space, containing multiple copies of a single space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
space : `gym.spaces.Space` instance
|
||||
Space (e.g. the observation space) for a single environment in the
|
||||
vectorized environment.
|
||||
|
||||
n : int
|
||||
Number of environments in the vectorized environment.
|
||||
|
||||
Returns
|
||||
-------
|
||||
batched_space : `gym.spaces.Space` instance
|
||||
Space (e.g. the observation space) for a batch of environments in the
|
||||
vectorized environment.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> from gym.spaces import Box, Dict
|
||||
>>> space = Dict({
|
||||
... 'position': Box(low=0, high=1, shape=(3,), dtype=np.float32),
|
||||
... 'velocity': Box(low=0, high=1, shape=(2,), dtype=np.float32)})
|
||||
>>> batch_space(space, n=5)
|
||||
Dict(position:Box(5, 3), velocity:Box(5, 2))
|
||||
"""
|
||||
if isinstance(space, _BaseGymSpaces):
|
||||
return batch_space_base(space, n=n)
|
||||
elif isinstance(space, Tuple):
|
||||
return batch_space_tuple(space, n=n)
|
||||
elif isinstance(space, Dict):
|
||||
return batch_space_dict(space, n=n)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def batch_space_base(space, n=1):
|
||||
if isinstance(space, Box):
|
||||
repeats = tuple([n] + [1] * space.low.ndim)
|
||||
low, high = np.tile(space.low, repeats), np.tile(space.high, repeats)
|
||||
return Box(low=low, high=high, dtype=space.dtype)
|
||||
|
||||
elif isinstance(space, Discrete):
|
||||
return MultiDiscrete(np.full((n,), space.n, dtype=space.dtype))
|
||||
|
||||
elif isinstance(space, MultiDiscrete):
|
||||
repeats = tuple([n] + [1] * space.nvec.ndim)
|
||||
high = np.tile(space.nvec, repeats) - 1
|
||||
return Box(low=np.zeros_like(high), high=high, dtype=space.dtype)
|
||||
|
||||
elif isinstance(space, MultiBinary):
|
||||
return Box(low=0, high=1, shape=(n,) + space.shape, dtype=space.dtype)
|
||||
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def batch_space_tuple(space, n=1):
|
||||
return Tuple(tuple(batch_space(subspace, n=n) for subspace in space.spaces))
|
||||
|
||||
def batch_space_dict(space, n=1):
|
||||
return Dict(OrderedDict([(key, batch_space(subspace, n=n))
|
||||
for (key, subspace) in space.spaces.items()]))
|
59
gym/vector/vector_env.py
Normal file
59
gym/vector/vector_env.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import gym
|
||||
from gym.spaces import Tuple
|
||||
from gym.vector.utils.spaces import batch_space
|
||||
|
||||
__all__ = ['VectorEnv']
|
||||
|
||||
|
||||
class VectorEnv(gym.Env):
|
||||
"""Base class for vectorized environments.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_envs : int
|
||||
Number of environments in the vectorized environment.
|
||||
|
||||
observation_space : `gym.spaces.Space` instance
|
||||
Observation space of a single environment.
|
||||
|
||||
action_space : `gym.spaces.Space` instance
|
||||
Action space of a single environment.
|
||||
"""
|
||||
def __init__(self, num_envs, observation_space, action_space):
|
||||
super(VectorEnv, self).__init__()
|
||||
self.num_envs = num_envs
|
||||
self.observation_space = batch_space(observation_space, n=num_envs)
|
||||
self.action_space = Tuple((action_space,) * num_envs)
|
||||
|
||||
self.closed = False
|
||||
self.viewer = None
|
||||
|
||||
# The observation and action spaces of a single environment are
|
||||
# kept in separate properties
|
||||
self.single_observation_space = observation_space
|
||||
self.single_action_space = action_space
|
||||
|
||||
def reset_async(self):
|
||||
pass
|
||||
|
||||
def reset_wait(self, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def reset(self):
|
||||
self.reset_async()
|
||||
return self.reset_wait()
|
||||
|
||||
def step_async(self, actions):
|
||||
pass
|
||||
|
||||
def step_wait(self, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def step(self, actions):
|
||||
self.step_async(actions)
|
||||
return self.step_wait()
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, 'closed'):
|
||||
if not self.closed:
|
||||
self.close()
|
3
setup.py
3
setup.py
@@ -31,7 +31,8 @@ setup(name='gym',
|
||||
if package.startswith('gym')],
|
||||
zip_safe=False,
|
||||
install_requires=[
|
||||
'scipy', 'numpy>=1.10.4', 'six', 'pyglet>=1.2.0',
|
||||
'scipy', 'numpy>=1.10.4', 'six', 'pyglet>=1.2.0', 'cloudpickle~=1.2.0',
|
||||
'enum34~=1.1.6;python_version<"3.4"'
|
||||
],
|
||||
extras_require=extras,
|
||||
package_data={'gym': [
|
||||
|
Reference in New Issue
Block a user