2021-12-08 22:14:15 +01:00
|
|
|
from typing import Optional, Union, List
|
|
|
|
|
2019-06-21 17:29:44 -04:00
|
|
|
import numpy as np
|
|
|
|
import multiprocessing as mp
|
|
|
|
import time
|
|
|
|
import sys
|
|
|
|
from enum import Enum
|
|
|
|
from copy import deepcopy
|
|
|
|
|
|
|
|
from gym import logger
|
2021-12-08 22:14:15 +01:00
|
|
|
from gym.logger import warn
|
2019-06-21 17:29:44 -04:00
|
|
|
from gym.vector.vector_env import VectorEnv
|
2021-07-29 02:26:34 +02:00
|
|
|
from gym.error import (
|
|
|
|
AlreadyPendingCallError,
|
|
|
|
NoAsyncCallError,
|
|
|
|
ClosedEnvironmentError,
|
|
|
|
CustomSpaceError,
|
|
|
|
)
|
|
|
|
from gym.vector.utils import (
|
|
|
|
create_shared_memory,
|
|
|
|
create_empty_array,
|
|
|
|
write_to_shared_memory,
|
|
|
|
read_from_shared_memory,
|
|
|
|
concatenate,
|
2021-12-08 21:31:41 -05:00
|
|
|
iterate,
|
2021-07-29 02:26:34 +02:00
|
|
|
CloudpickleWrapper,
|
|
|
|
clear_mpi_env_vars,
|
|
|
|
)
|
|
|
|
|
|
|
|
__all__ = ["AsyncVectorEnv"]
|
2019-06-21 17:29:44 -04:00
|
|
|
|
|
|
|
|
|
|
|
class AsyncState(Enum):
    """Pending-call state of an :class:`AsyncVectorEnv`.

    The value of each ``WAITING_*`` member is the prefix of the matching
    ``*_wait`` method; ``AsyncVectorEnv.close_extras`` relies on this to look
    up the method that drains a pending call.
    """

    # No asynchronous call is pending.
    DEFAULT = "default"
    # `reset_async` was called; `reset_wait` has not collected the results yet.
    WAITING_RESET = "reset"
    # `step_async` was called; `step_wait` has not collected the results yet.
    WAITING_STEP = "step"
|
2019-06-21 17:29:44 -04:00
|
|
|
|
|
|
|
|
|
|
|
class AsyncVectorEnv(VectorEnv):
    """Vectorized environment that runs multiple environments in parallel.

    It uses `multiprocessing`_ processes, and pipes for communication.

    Parameters
    ----------
    env_fns : iterable of callable
        Functions that create the environments. The collection is indexed
        (``env_fns[0]``) and measured (``len(env_fns)``), so it must be a
        sequence.

    observation_space : :class:`gym.spaces.Space`, optional
        Observation space of a single environment. If ``None``, then the
        observation space of the first environment is taken.

    action_space : :class:`gym.spaces.Space`, optional
        Action space of a single environment. If ``None``, then the action space
        of the first environment is taken.

    shared_memory : bool
        If ``True``, then the observations from the worker processes are
        communicated back through shared variables. This can improve the
        efficiency if the observations are large (e.g. images).

    copy : bool
        If ``True``, then the :meth:`~AsyncVectorEnv.reset` and
        :meth:`~AsyncVectorEnv.step` methods return a copy of the observations.

    context : str, optional
        Context for `multiprocessing`_. If ``None``, then the default context is used.

    daemon : bool
        If ``True``, then subprocesses have ``daemon`` flag turned on; that is, they
        will quit if the head process quits. However, ``daemon=True`` prevents
        subprocesses from spawning children, so for some environments you may want
        to have it set to ``False``.

    worker : callable, optional
        If set, then use that worker in a subprocess instead of a default one.
        Can be useful to override some inner vector env logic, for instance,
        how resets on done are handled.

    Warning
    -------
    :attr:`worker` is an advanced mode option. It provides a high degree of
    flexibility and a high chance to shoot yourself in the foot; thus,
    if you are writing your own worker, it is recommended to start from the code
    for ``_worker`` (or ``_worker_shared_memory``) method, and add changes.

    Raises
    ------
    RuntimeError
        If the observation space (or action space) of some sub-environment
        does not match :obj:`observation_space` (resp. :obj:`action_space`)
        or, by default, the corresponding space of the first sub-environment.

    ValueError
        If :obj:`observation_space` is a custom space (i.e. not a default
        space in Gym, such as :class:`~gym.spaces.Box`, :class:`~gym.spaces.Discrete`,
        or :class:`~gym.spaces.Dict`) and :obj:`shared_memory` is ``True``.

    Example
    -------

    .. code-block::

        >>> env = gym.vector.AsyncVectorEnv([
        ...     lambda: gym.make("Pendulum-v0", g=9.81),
        ...     lambda: gym.make("Pendulum-v0", g=1.62)
        ... ])
        >>> env.reset()
        array([[-0.8286432 ,  0.5597771 ,  0.90249056],
               [-0.85009176,  0.5266346 ,  0.60007906]], dtype=float32)
    """

    def __init__(
        self,
        env_fns,
        observation_space=None,
        action_space=None,
        shared_memory=True,
        copy=True,
        context=None,
        daemon=True,
        worker=None,
    ):
        ctx = mp.get_context(context)
        self.env_fns = env_fns
        self.shared_memory = shared_memory
        self.copy = copy

        # Instantiate the first environment in the parent process only to read
        # its metadata and default spaces; it is closed again right away, even
        # if reading one of its attributes fails.
        dummy_env = env_fns[0]()
        try:
            self.metadata = dummy_env.metadata
            # Test explicitly for `None` rather than truthiness, so that a
            # caller-supplied space is never replaced merely because it
            # evaluates as falsy.
            if observation_space is None:
                observation_space = dummy_env.observation_space
            if action_space is None:
                action_space = dummy_env.action_space
        finally:
            dummy_env.close()
        del dummy_env

        super().__init__(
            num_envs=len(env_fns),
            observation_space=observation_space,
            action_space=action_space,
        )

        if self.shared_memory:
            try:
                _obs_buffer = create_shared_memory(
                    self.single_observation_space, n=self.num_envs, ctx=ctx
                )
                self.observations = read_from_shared_memory(
                    _obs_buffer, self.single_observation_space, n=self.num_envs
                )
            except CustomSpaceError as err:
                raise ValueError(
                    "Using `shared_memory=True` in `AsyncVectorEnv` "
                    "is incompatible with non-standard Gym observation spaces "
                    "(i.e. custom spaces inheriting from `gym.Space`), and is "
                    "only compatible with default Gym spaces (e.g. `Box`, "
                    "`Tuple`, `Dict`) for batching. Set `shared_memory=False` "
                    "if you use custom observation spaces."
                ) from err
        else:
            _obs_buffer = None
            self.observations = create_empty_array(
                self.single_observation_space, n=self.num_envs, fn=np.zeros
            )

        self.parent_pipes, self.processes = [], []
        self.error_queue = ctx.Queue()
        # A user-supplied worker takes precedence over the built-in ones.
        target = _worker_shared_memory if self.shared_memory else _worker
        target = worker or target
        with clear_mpi_env_vars():
            for idx, env_fn in enumerate(self.env_fns):
                parent_pipe, child_pipe = ctx.Pipe()
                process = ctx.Process(
                    target=target,
                    name=f"Worker<{type(self).__name__}>-{idx}",
                    args=(
                        idx,
                        CloudpickleWrapper(env_fn),
                        child_pipe,
                        parent_pipe,
                        _obs_buffer,
                        self.error_queue,
                    ),
                )

                self.parent_pipes.append(parent_pipe)
                self.processes.append(process)

                process.daemon = daemon
                process.start()
                # The parent only talks through `parent_pipe`; its handle on
                # the child's end is no longer needed once the worker started.
                child_pipe.close()

        self._state = AsyncState.DEFAULT
        self._check_spaces()

    def seed(self, seed=None):
        """Seed every sub-environment.

        Parameters
        ----------
        seed : int or list of int, optional
            If an ``int``, sub-environment ``i`` is seeded with ``seed + i``.
            If a list, it must contain one seed per sub-environment. If
            ``None``, every sub-environment is seeded with ``None``.

        Raises
        ------
        ClosedEnvironmentError
            If the environment was closed (if :meth:`close` was previously called).

        AlreadyPendingCallError
            If a call to :meth:`reset_async` or :meth:`step_async` is still
            pending.
        """
        super().seed(seed=seed)
        self._assert_is_running()
        if seed is None:
            seed = [None for _ in range(self.num_envs)]
        if isinstance(seed, int):
            seed = [seed + i for i in range(self.num_envs)]
        assert len(seed) == self.num_envs

        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `seed` while waiting for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )

        # Use a distinct loop name so the `seed` argument is not rebound.
        for pipe, single_seed in zip(self.parent_pipes, seed):
            pipe.send(("seed", single_seed))
        _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)

    def reset_async(
        self,
        seed: Optional[Union[int, List[int]]] = None,
        options: Optional[dict] = None,
    ):
        """Send the calls to :obj:`reset` to each sub-environment.

        Parameters
        ----------
        seed : int or list of int, optional
            If an ``int``, sub-environment ``i`` is reset with ``seed + i``.
            If a list, it must contain one seed per sub-environment. Entries
            that are ``None`` (or ``seed=None``) result in no ``seed`` keyword
            being passed to that sub-environment's ``reset``.

        options : dict, optional
            If not ``None``, passed as the ``options`` keyword to the ``reset``
            of every sub-environment.

        Raises
        ------
        ClosedEnvironmentError
            If the environment was closed (if :meth:`close` was previously called).

        AlreadyPendingCallError
            If the environment is already waiting for a pending call to another
            method (e.g. :meth:`step_async`). This can be caused by two consecutive
            calls to :meth:`reset_async`, with no call to :meth:`reset_wait` in
            between.
        """
        self._assert_is_running()

        if seed is None:
            seed = [None for _ in range(self.num_envs)]
        if isinstance(seed, int):
            seed = [seed + i for i in range(self.num_envs)]
        assert len(seed) == self.num_envs

        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `reset_async` while waiting for a pending call to `{self._state.value}` to complete",
                self._state.value,
            )

        for pipe, single_seed in zip(self.parent_pipes, seed):
            # Only forward the keywords that were actually provided.
            single_kwargs = {}
            if single_seed is not None:
                single_kwargs["seed"] = single_seed
            if options is not None:
                single_kwargs["options"] = options

            pipe.send(("reset", single_kwargs))
        self._state = AsyncState.WAITING_RESET

    def reset_wait(
        self, timeout=None, seed: Optional[int] = None, options: Optional[dict] = None
    ):
        """Wait for the calls to :obj:`reset` in each sub-environment to finish.

        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to `reset_wait` times out. If
            `None`, the call to `reset_wait` never times out.
        seed: ignored
        options: ignored

        Returns
        -------
        element of :attr:`~VectorEnv.observation_space`
            A batch of observations from the vectorized environment.

        Raises
        ------
        ClosedEnvironmentError
            If the environment was closed (if :meth:`close` was previously called).

        NoAsyncCallError
            If :meth:`reset_wait` was called without any prior call to
            :meth:`reset_async`.

        TimeoutError
            If :meth:`reset_wait` timed out.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_RESET:
            raise NoAsyncCallError(
                "Calling `reset_wait` without any prior " "call to `reset_async`.",
                AsyncState.WAITING_RESET.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `reset_wait` has timed out after {timeout} second{'s' if timeout > 1 else ''}."
            )

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT

        # Without shared memory the observations travel through the pipes and
        # have to be batched here.
        if not self.shared_memory:
            self.observations = concatenate(
                results, self.observations, self.single_observation_space
            )

        return deepcopy(self.observations) if self.copy else self.observations

    def step_async(self, actions):
        """Send the calls to :obj:`step` to each sub-environment.

        Parameters
        ----------
        actions : element of :attr:`~VectorEnv.action_space`
            Batch of actions.

        Raises
        ------
        ClosedEnvironmentError
            If the environment was closed (if :meth:`close` was previously called).

        AlreadyPendingCallError
            If the environment is already waiting for a pending call to another
            method (e.g. :meth:`reset_async`). This can be caused by two consecutive
            calls to :meth:`step_async`, with no call to :meth:`step_wait` in
            between.
        """
        self._assert_is_running()
        if self._state != AsyncState.DEFAULT:
            raise AlreadyPendingCallError(
                f"Calling `step_async` while waiting for a pending call to `{self._state.value}` to complete.",
                self._state.value,
            )

        actions = iterate(self.action_space, actions)
        for pipe, action in zip(self.parent_pipes, actions):
            pipe.send(("step", action))
        self._state = AsyncState.WAITING_STEP

    def step_wait(self, timeout=None):
        """Wait for the calls to :obj:`step` in each sub-environment to finish.

        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to :meth:`step_wait` times out. If
            ``None``, the call to :meth:`step_wait` never times out.

        Returns
        -------
        observations : element of :attr:`~VectorEnv.observation_space`
            A batch of observations from the vectorized environment.

        rewards : :obj:`np.ndarray`
            A vector of rewards from the vectorized environment.

        dones : :obj:`np.ndarray`, dtype :obj:`np.bool_`
            A vector whose entries indicate whether the episode has ended.

        infos : tuple of dict
            The auxiliary diagnostic information dicts from sub-environments.

        Raises
        ------
        ClosedEnvironmentError
            If the environment was closed (if :meth:`close` was previously called).

        NoAsyncCallError
            If :meth:`step_wait` was called without any prior call to
            :meth:`step_async`.

        TimeoutError
            If :meth:`step_wait` timed out.
        """
        self._assert_is_running()
        if self._state != AsyncState.WAITING_STEP:
            raise NoAsyncCallError(
                "Calling `step_wait` without any prior call " "to `step_async`.",
                AsyncState.WAITING_STEP.value,
            )

        if not self._poll(timeout):
            self._state = AsyncState.DEFAULT
            raise mp.TimeoutError(
                f"The call to `step_wait` has timed out after {timeout} second{'s' if timeout > 1 else ''}."
            )

        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        self._state = AsyncState.DEFAULT
        observations_list, rewards, dones, infos = zip(*results)

        # Without shared memory the observations travel through the pipes and
        # have to be batched here.
        if not self.shared_memory:
            self.observations = concatenate(
                observations_list, self.observations, self.single_observation_space
            )

        return (
            deepcopy(self.observations) if self.copy else self.observations,
            np.array(rewards),
            np.array(dones, dtype=np.bool_),
            infos,
        )

    def close_extras(self, timeout=None, terminate=False):
        """Close the environments & clean up the extra resources
        (processes and pipes).

        Parameters
        ----------
        timeout : int or float, optional
            Number of seconds before the call to :meth:`close` times out. If ``None``,
            the call to :meth:`close` never times out. If the call to :meth:`close`
            times out, then all processes are terminated.

        terminate : bool
            If ``True``, then the :meth:`close` operation is forced and all processes
            are terminated.

        Raises
        ------
        TimeoutError
            If :meth:`close` timed out.
        """
        timeout = 0 if terminate else timeout
        try:
            if self._state != AsyncState.DEFAULT:
                logger.warn(
                    f"Calling `close` while waiting for a pending call to `{self._state.value}` to complete."
                )
                # Drain the pending call through `reset_wait` / `step_wait`,
                # selected from the state's value.
                function = getattr(self, f"{self._state.value}_wait")
                function(timeout)
        except mp.TimeoutError:
            terminate = True

        if terminate:
            for process in self.processes:
                if process.is_alive():
                    process.terminate()
        else:
            # Ask every worker that is still reachable to shut down, then wait
            # for each acknowledgement.
            for pipe in self.parent_pipes:
                if (pipe is not None) and (not pipe.closed):
                    pipe.send(("close", None))
            for pipe in self.parent_pipes:
                if (pipe is not None) and (not pipe.closed):
                    pipe.recv()

        for pipe in self.parent_pipes:
            if pipe is not None:
                pipe.close()
        for process in self.processes:
            process.join()

    def _poll(self, timeout=None):
        """Check whether every worker has data ready within ``timeout`` seconds.

        Returns ``True`` immediately when ``timeout`` is ``None`` (the caller
        then blocks on ``recv``). Otherwise the time budget is shared across
        all pipes, and ``False`` is returned as soon as a pipe is missing,
        closed, or has nothing to read before the deadline.
        """
        self._assert_is_running()
        if timeout is None:
            return True
        end_time = time.perf_counter() + timeout
        for pipe in self.parent_pipes:
            remaining = max(end_time - time.perf_counter(), 0)
            if pipe is None:
                return False
            if pipe.closed or (not pipe.poll(remaining)):
                return False
        return True

    def _check_spaces(self):
        """Ask every worker whether its spaces equal the single-env spaces.

        Raises
        ------
        RuntimeError
            If some sub-environment reports an observation space or an action
            space different from the one used for batching.
        """
        self._assert_is_running()
        spaces = (self.single_observation_space, self.single_action_space)
        for pipe in self.parent_pipes:
            pipe.send(("_check_spaces", spaces))
        results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
        self._raise_if_errors(successes)
        same_observation_spaces, same_action_spaces = zip(*results)
        if not all(same_observation_spaces):
            raise RuntimeError(
                "Some environments have an observation space different from "
                f"`{self.single_observation_space}`. In order to batch observations, "
                "the observation spaces from all environments must be equal."
            )
        if not all(same_action_spaces):
            raise RuntimeError(
                "Some environments have an action space different from "
                f"`{self.single_action_space}`. In order to batch actions, the "
                "action spaces from all environments must be equal."
            )

    def _assert_is_running(self):
        """Raise :class:`ClosedEnvironmentError` if ``self.closed`` is set."""
        if self.closed:
            raise ClosedEnvironmentError(
                f"Trying to operate on `{type(self).__name__}`, after a call to `close()`."
            )

    def _raise_if_errors(self, successes):
        """Re-raise a worker failure in the main process.

        ``successes`` holds one flag per worker. For every failed worker one
        ``(index, exctype, value)`` entry is taken from the error queue, the
        error is logged and that worker's pipe is closed and replaced by
        ``None``. The last exception read is then raised as ``exctype(value)``.
        """
        if all(successes):
            return

        num_errors = self.num_envs - sum(successes)
        assert num_errors > 0
        for _ in range(num_errors):
            index, exctype, value = self.error_queue.get()
            logger.error(
                f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
            )
            logger.error(f"Shutting down Worker-{index}.")
            self.parent_pipes[index].close()
            self.parent_pipes[index] = None

        logger.error("Raising the last exception back to the main process.")
        raise exctype(value)

    def __del__(self):
        # `closed` may be missing if `__init__` failed early; treat that case
        # as already closed.
        if not getattr(self, "closed", True):
            self.close(terminate=True)
|
|
|
|
|
2019-06-21 17:29:44 -04:00
|
|
|
|
|
|
|
def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
    """Worker loop that returns observations through the pipe.

    Parameters
    ----------
    index : int
        Index of this worker; reported together with any error.
    env_fn : callable
        Called once, with no argument, to create the environment.
    pipe : connection
        Child end of the pipe: ``(command, data)`` tuples are received on it
        and ``(result, success)`` tuples are sent back.
    parent_pipe : connection
        Parent end of the pipe; closed immediately since the worker does not
        use it.
    shared_memory : None
        Must be ``None`` for this worker.
    error_queue : queue
        Receives ``(index, exception type, exception value)`` when a command
        fails.

    Supported commands are ``reset``, ``step``, ``seed``, ``close`` and
    ``_check_spaces``. On ``step``, a finished episode is reset right away and
    the final observation is stored in ``info["terminal_observation"]``.
    Any failure is put on ``error_queue``, ``(None, False)`` is sent back and
    the loop ends; the environment is always closed on exit.
    """
    assert shared_memory is None
    env = env_fn()
    parent_pipe.close()
    try:
        while True:
            command, data = pipe.recv()
            if command == "reset":
                observation = env.reset(**data)
                pipe.send((observation, True))
            elif command == "step":
                observation, reward, done, info = env.step(data)
                if done:
                    info["terminal_observation"] = observation
                    observation = env.reset()
                pipe.send(((observation, reward, done, info), True))
            elif command == "seed":
                env.seed(data)
                pipe.send((None, True))
            elif command == "close":
                pipe.send((None, True))
                break
            elif command == "_check_spaces":
                pipe.send(
                    (
                        (data[0] == env.observation_space, data[1] == env.action_space),
                        True,
                    )
                )
            else:
                # The literal braces are doubled: with `str.format` the brace
                # group listing the commands was parsed as a replacement field
                # and raised `KeyError` instead of this `RuntimeError`.
                raise RuntimeError(
                    f"Received unknown command `{command}`. Must "
                    "be one of {`reset`, `step`, `seed`, `close`, "
                    "`_check_spaces`}."
                )
    except (KeyboardInterrupt, Exception):
        error_queue.put((index,) + sys.exc_info()[:2])
        pipe.send((None, False))
    finally:
        env.close()
|
|
|
|
|
|
|
|
|
|
|
|
def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
    """Worker loop that returns observations through shared memory.

    Parameters
    ----------
    index : int
        Index of this worker; passed to ``write_to_shared_memory`` and
        reported together with any error.
    env_fn : callable
        Called once, with no argument, to create the environment.
    pipe : connection
        Child end of the pipe: ``(command, data)`` tuples are received on it
        and ``(result, success)`` tuples are sent back.
    parent_pipe : connection
        Parent end of the pipe; closed immediately since the worker does not
        use it.
    shared_memory : object
        Buffer handed to ``write_to_shared_memory``; must not be ``None``.
    error_queue : queue
        Receives ``(index, exception type, exception value)`` when a command
        fails.

    Supported commands are ``reset``, ``step``, ``seed``, ``close`` and
    ``_check_spaces``. Observations from ``reset`` and ``step`` are written to
    ``shared_memory`` and ``None`` is sent in their place. On ``step``, a
    finished episode is reset right away and the final observation is stored
    in ``info["terminal_observation"]``. Any failure is put on
    ``error_queue``, ``(None, False)`` is sent back and the loop ends; the
    environment is always closed on exit.
    """
    assert shared_memory is not None
    env = env_fn()
    observation_space = env.observation_space
    parent_pipe.close()
    try:
        while True:
            command, data = pipe.recv()
            if command == "reset":
                observation = env.reset(**data)
                write_to_shared_memory(
                    index, observation, shared_memory, observation_space
                )
                pipe.send((None, True))
            elif command == "step":
                observation, reward, done, info = env.step(data)
                if done:
                    info["terminal_observation"] = observation
                    observation = env.reset()
                write_to_shared_memory(
                    index, observation, shared_memory, observation_space
                )
                pipe.send(((None, reward, done, info), True))
            elif command == "seed":
                env.seed(data)
                pipe.send((None, True))
            elif command == "close":
                pipe.send((None, True))
                break
            elif command == "_check_spaces":
                pipe.send(
                    ((data[0] == observation_space, data[1] == env.action_space), True)
                )
            else:
                # The literal braces are doubled: with `str.format` the brace
                # group listing the commands was parsed as a replacement field
                # and raised `KeyError` instead of this `RuntimeError`.
                raise RuntimeError(
                    f"Received unknown command `{command}`. Must "
                    "be one of {`reset`, `step`, `seed`, `close`, "
                    "`_check_spaces`}."
                )
    except (KeyboardInterrupt, Exception):
        error_queue.put((index,) + sys.exc_info()[:2])
        pipe.send((None, False))
    finally:
        env.close()
|