diff --git a/baselines/common/tests/test_with_mpi.py b/baselines/common/tests/test_with_mpi.py index 0d71a64..3178a5a 100644 --- a/baselines/common/tests/test_with_mpi.py +++ b/baselines/common/tests/test_with_mpi.py @@ -7,7 +7,7 @@ import pytest from mpi4py import MPI -def test_with_mpi(nproc=2, timeout=5, skip_if_no_mpi=True): +def test_with_mpi(nproc=2, timeout=10, skip_if_no_mpi=True): def outer_thunk(fn): def thunk(*args, **kwargs): serialized_fn = base64.b64encode(cloudpickle.dumps(lambda: fn(*args, **kwargs))) diff --git a/baselines/common/vec_env/__init__.py b/baselines/common/vec_env/__init__.py index 9c3703e..2817fad 100644 --- a/baselines/common/vec_env/__init__.py +++ b/baselines/common/vec_env/__init__.py @@ -1,6 +1,10 @@ +import contextlib +import os from abc import ABC, abstractmethod + from baselines.common.tile_images import tile_images + class AlreadySteppingError(Exception): """ Raised when an asynchronous step is running while @@ -181,6 +185,7 @@ class VecEnvObservationWrapper(VecEnvWrapper): obs, rews, dones, infos = self.venv.step_wait() return self.process(obs), rews, dones, infos + class CloudpickleWrapper(object): """ Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) @@ -196,3 +201,23 @@ class CloudpickleWrapper(object): def __setstate__(self, ob): import pickle self.x = pickle.loads(ob) + + +@contextlib.contextmanager +def clear_mpi_env_vars(): + """ + from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang. + + This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing + Processes. + """ + removed_environment = {} + for k, v in list(os.environ.items()): + for prefix in ['OMPI_', 'PMI_']: + if k.startswith(prefix): + removed_environment[k] = v + del os.environ[k] + try: + yield + finally: + os.environ.update(removed_environment) \ No newline at end of file diff --git a/baselines/common/vec_env/shmem_vec_env.py b/baselines/common/vec_env/shmem_vec_env.py index fcdcf47..986d078 100644 --- a/baselines/common/vec_env/shmem_vec_env.py +++ b/baselines/common/vec_env/shmem_vec_env.py @@ -2,9 +2,9 @@ An interface for asynchronous vectorized environments. """ -from multiprocessing import Pipe, Array, Process +import multiprocessing as mp import numpy as np -from . import VecEnv, CloudpickleWrapper +from . import VecEnv, CloudpickleWrapper, clear_mpi_env_vars import ctypes from baselines import logger @@ -16,6 +16,8 @@ _NP_TO_CT = {np.float32: ctypes.c_float, np.uint8: ctypes.c_char, np.bool: ctypes.c_bool} +ctx = mp.get_context('spawn') + class ShmemVecEnv(VecEnv): """ @@ -39,20 +41,21 @@ class ShmemVecEnv(VecEnv): VecEnv.__init__(self, len(env_fns), observation_space, action_space) self.obs_keys, self.obs_shapes, self.obs_dtypes = obs_space_info(observation_space) self.obs_bufs = [ - {k: Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k]))) for k in self.obs_keys} + {k: ctx.Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k]))) for k in self.obs_keys} for _ in env_fns] self.parent_pipes = [] self.procs = [] - for env_fn, obs_buf in zip(env_fns, self.obs_bufs): - wrapped_fn = CloudpickleWrapper(env_fn) - parent_pipe, child_pipe = Pipe() - proc = Process(target=_subproc_worker, - args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes, self.obs_dtypes, self.obs_keys)) - proc.daemon = True - self.procs.append(proc) - self.parent_pipes.append(parent_pipe) - proc.start() - child_pipe.close() + with clear_mpi_env_vars(): + for env_fn, obs_buf in zip(env_fns, self.obs_bufs): + wrapped_fn = CloudpickleWrapper(env_fn) + parent_pipe, child_pipe = ctx.Pipe() + proc = ctx.Process(target=_subproc_worker, + args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes, self.obs_dtypes, self.obs_keys)) + proc.daemon = True + self.procs.append(proc) + self.parent_pipes.append(parent_pipe) + proc.start() + child_pipe.close() self.waiting_step = False self.viewer = None diff --git a/baselines/common/vec_env/subproc_vec_env.py b/baselines/common/vec_env/subproc_vec_env.py index 2d72508..72cdb93 100644 --- a/baselines/common/vec_env/subproc_vec_env.py +++ b/baselines/common/vec_env/subproc_vec_env.py @@ -1,28 +1,7 @@ -import contextlib import multiprocessing as mp -import os import numpy as np -from . import VecEnv, CloudpickleWrapper - -@contextlib.contextmanager -def clear_mpi_env_vars(): - """ - from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang. - - This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing - Processes. - """ - removed_environment = {} - for k, v in list(os.environ.items()): - for prefix in ['OMPI_', 'PMI_']: - if k.startswith(prefix): - removed_environment[k] = v - del os.environ[k] - try: - yield - finally: - os.environ.update(removed_environment) +from . import VecEnv, CloudpickleWrapper, clear_mpi_env_vars ctx = mp.get_context('spawn')