From 8c547e597310eea21766be9fa94bd6b9c3056051 Mon Sep 17 00:00:00 2001 From: Christopher Hesse Date: Tue, 15 Jan 2019 13:05:04 -0800 Subject: [PATCH] use spawn for shmem vec env as well (#2) (#219) * lazy_mpi load * cleanups * more lazy mpi * don't pretend that class is a module, just use it as a class * mass-replace mpi4py imports * flake8 * fix previous lazy_mpi imports * silly recursion * try os.environ hack * better prefix test, work with mpich * restored MPI imports * removed commented import in test_with_mpi * restored codegen from master * remove lazy mpi * restored changes from rl-algs * remove extra files * port mpi fix to shmem vec env * increase the mpi test default timeout --- baselines/common/tests/test_with_mpi.py | 2 +- baselines/common/vec_env/__init__.py | 25 ++++++++++++++++++ baselines/common/vec_env/shmem_vec_env.py | 29 ++++++++++++--------- baselines/common/vec_env/subproc_vec_env.py | 23 +--------------- 4 files changed, 43 insertions(+), 36 deletions(-) diff --git a/baselines/common/tests/test_with_mpi.py b/baselines/common/tests/test_with_mpi.py index 0d71a64..3178a5a 100644 --- a/baselines/common/tests/test_with_mpi.py +++ b/baselines/common/tests/test_with_mpi.py @@ -7,7 +7,7 @@ import pytest from mpi4py import MPI -def test_with_mpi(nproc=2, timeout=5, skip_if_no_mpi=True): +def test_with_mpi(nproc=2, timeout=10, skip_if_no_mpi=True): def outer_thunk(fn): def thunk(*args, **kwargs): serialized_fn = base64.b64encode(cloudpickle.dumps(lambda: fn(*args, **kwargs))) diff --git a/baselines/common/vec_env/__init__.py b/baselines/common/vec_env/__init__.py index 9c3703e..2817fad 100644 --- a/baselines/common/vec_env/__init__.py +++ b/baselines/common/vec_env/__init__.py @@ -1,6 +1,10 @@ +import contextlib +import os from abc import ABC, abstractmethod + from baselines.common.tile_images import tile_images + class AlreadySteppingError(Exception): """ Raised when an asynchronous step is running while @@ -181,6 +185,7 @@ class VecEnvObservationWrapper(VecEnvWrapper): obs, rews, dones, infos = self.venv.step_wait() return self.process(obs), rews, dones, infos + class CloudpickleWrapper(object): """ Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) @@ -196,3 +201,23 @@ class CloudpickleWrapper(object): def __setstate__(self, ob): import pickle self.x = pickle.loads(ob) + + +@contextlib.contextmanager +def clear_mpi_env_vars(): + """ + from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang. + + This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing + Processes. + """ + removed_environment = {} + for k, v in list(os.environ.items()): + for prefix in ['OMPI_', 'PMI_']: + if k.startswith(prefix): + removed_environment[k] = v + del os.environ[k] + try: + yield + finally: + os.environ.update(removed_environment) \ No newline at end of file diff --git a/baselines/common/vec_env/shmem_vec_env.py b/baselines/common/vec_env/shmem_vec_env.py index fcdcf47..986d078 100644 --- a/baselines/common/vec_env/shmem_vec_env.py +++ b/baselines/common/vec_env/shmem_vec_env.py @@ -2,9 +2,9 @@ An interface for asynchronous vectorized environments. """ -from multiprocessing import Pipe, Array, Process +import multiprocessing as mp import numpy as np -from . import VecEnv, CloudpickleWrapper +from . import VecEnv, CloudpickleWrapper, clear_mpi_env_vars import ctypes from baselines import logger @@ -16,6 +16,8 @@ _NP_TO_CT = {np.float32: ctypes.c_float, np.uint8: ctypes.c_char, np.bool: ctypes.c_bool} +ctx = mp.get_context('spawn') + class ShmemVecEnv(VecEnv): """ @@ -39,20 +41,21 @@ class ShmemVecEnv(VecEnv): VecEnv.__init__(self, len(env_fns), observation_space, action_space) self.obs_keys, self.obs_shapes, self.obs_dtypes = obs_space_info(observation_space) self.obs_bufs = [ - {k: Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k]))) for k in self.obs_keys} + {k: ctx.Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k]))) for k in self.obs_keys} for _ in env_fns] self.parent_pipes = [] self.procs = [] - for env_fn, obs_buf in zip(env_fns, self.obs_bufs): - wrapped_fn = CloudpickleWrapper(env_fn) - parent_pipe, child_pipe = Pipe() - proc = Process(target=_subproc_worker, - args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes, self.obs_dtypes, self.obs_keys)) - proc.daemon = True - self.procs.append(proc) - self.parent_pipes.append(parent_pipe) - proc.start() - child_pipe.close() + with clear_mpi_env_vars(): + for env_fn, obs_buf in zip(env_fns, self.obs_bufs): + wrapped_fn = CloudpickleWrapper(env_fn) + parent_pipe, child_pipe = ctx.Pipe() + proc = ctx.Process(target=_subproc_worker, + args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes, self.obs_dtypes, self.obs_keys)) + proc.daemon = True + self.procs.append(proc) + self.parent_pipes.append(parent_pipe) + proc.start() + child_pipe.close() self.waiting_step = False self.viewer = None diff --git a/baselines/common/vec_env/subproc_vec_env.py b/baselines/common/vec_env/subproc_vec_env.py index 2d72508..72cdb93 100644 --- a/baselines/common/vec_env/subproc_vec_env.py +++ b/baselines/common/vec_env/subproc_vec_env.py @@ -1,28 +1,7 @@ -import contextlib import multiprocessing as mp -import os import numpy as np -from . import VecEnv, CloudpickleWrapper - -@contextlib.contextmanager -def clear_mpi_env_vars(): - """ - from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang. - - This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing - Processes. - """ - removed_environment = {} - for k, v in list(os.environ.items()): - for prefix in ['OMPI_', 'PMI_']: - if k.startswith(prefix): - removed_environment[k] = v - del os.environ[k] - try: - yield - finally: - os.environ.update(removed_environment) +from . import VecEnv, CloudpickleWrapper, clear_mpi_env_vars ctx = mp.get_context('spawn')