use spawn for shmem vec env as well (#2) (#219)

* lazy_mpi load

* cleanups

* more lazy mpi

* don't pretend that class is a module, just use it as a class

* mass-replace mpi4py imports

* flake8

* fix previous lazy_mpi imports

* silly recursion

* try os.environ hack

* better prefix test, work with mpich

* restored MPI imports

* removed commented import in test_with_mpi

* restored codegen from master

* remove lazy mpi

* restored changes from rl-algs

* remove extra files

* port mpi fix to shmem vec env

* increase the mpi test default timeout
This commit is contained in:
Christopher Hesse
2019-01-15 13:05:04 -08:00
committed by Peter Zhokhov
parent a538e3c8f7
commit 8c547e5973
4 changed files with 43 additions and 36 deletions

View File

@@ -7,7 +7,7 @@ import pytest
from mpi4py import MPI
def test_with_mpi(nproc=2, timeout=5, skip_if_no_mpi=True):
def test_with_mpi(nproc=2, timeout=10, skip_if_no_mpi=True):
def outer_thunk(fn):
def thunk(*args, **kwargs):
serialized_fn = base64.b64encode(cloudpickle.dumps(lambda: fn(*args, **kwargs)))

View File

@@ -1,6 +1,10 @@
import contextlib
import os
from abc import ABC, abstractmethod
from baselines.common.tile_images import tile_images
class AlreadySteppingError(Exception):
"""
Raised when an asynchronous step is running while
@@ -181,6 +185,7 @@ class VecEnvObservationWrapper(VecEnvWrapper):
obs, rews, dones, infos = self.venv.step_wait()
return self.process(obs), rews, dones, infos
class CloudpickleWrapper(object):
"""
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
@@ -196,3 +201,23 @@ class CloudpickleWrapper(object):
def __setstate__(self, ob):
import pickle
self.x = pickle.loads(ob)
@contextlib.contextmanager
def clear_mpi_env_vars():
"""
from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang.
This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing
Processes.
"""
removed_environment = {}
for k, v in list(os.environ.items()):
for prefix in ['OMPI_', 'PMI_']:
if k.startswith(prefix):
removed_environment[k] = v
del os.environ[k]
try:
yield
finally:
os.environ.update(removed_environment)

View File

@@ -2,9 +2,9 @@
An interface for asynchronous vectorized environments.
"""
from multiprocessing import Pipe, Array, Process
import multiprocessing as mp
import numpy as np
from . import VecEnv, CloudpickleWrapper
from . import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
import ctypes
from baselines import logger
@@ -16,6 +16,8 @@ _NP_TO_CT = {np.float32: ctypes.c_float,
np.uint8: ctypes.c_char,
np.bool: ctypes.c_bool}
ctx = mp.get_context('spawn')
class ShmemVecEnv(VecEnv):
"""
@@ -39,20 +41,21 @@ class ShmemVecEnv(VecEnv):
VecEnv.__init__(self, len(env_fns), observation_space, action_space)
self.obs_keys, self.obs_shapes, self.obs_dtypes = obs_space_info(observation_space)
self.obs_bufs = [
{k: Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k]))) for k in self.obs_keys}
{k: ctx.Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k]))) for k in self.obs_keys}
for _ in env_fns]
self.parent_pipes = []
self.procs = []
for env_fn, obs_buf in zip(env_fns, self.obs_bufs):
wrapped_fn = CloudpickleWrapper(env_fn)
parent_pipe, child_pipe = Pipe()
proc = Process(target=_subproc_worker,
args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes, self.obs_dtypes, self.obs_keys))
proc.daemon = True
self.procs.append(proc)
self.parent_pipes.append(parent_pipe)
proc.start()
child_pipe.close()
with clear_mpi_env_vars():
for env_fn, obs_buf in zip(env_fns, self.obs_bufs):
wrapped_fn = CloudpickleWrapper(env_fn)
parent_pipe, child_pipe = ctx.Pipe()
proc = ctx.Process(target=_subproc_worker,
args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes, self.obs_dtypes, self.obs_keys))
proc.daemon = True
self.procs.append(proc)
self.parent_pipes.append(parent_pipe)
proc.start()
child_pipe.close()
self.waiting_step = False
self.viewer = None

View File

@@ -1,28 +1,7 @@
import contextlib
import multiprocessing as mp
import os
import numpy as np
from . import VecEnv, CloudpickleWrapper
@contextlib.contextmanager
def clear_mpi_env_vars():
"""
from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang.
This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing
Processes.
"""
removed_environment = {}
for k, v in list(os.environ.items()):
for prefix in ['OMPI_', 'PMI_']:
if k.startswith(prefix):
removed_environment[k] = v
del os.environ[k]
try:
yield
finally:
os.environ.update(removed_environment)
from . import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
ctx = mp.get_context('spawn')