disable mpi in subprocesses (#213)

* lazy_mpi load

* cleanups

* more lazy mpi

* don't pretend that class is a module, just use it as a class

* mass-replace mpi4py imports

* flake8

* fix previous lazy_mpi imports

* silly recursion

* try os.environ hack

* better prefix test, work with mpich

* restored MPI imports

* removed commented import in test_with_mpi

* restored codegen from master

* remove lazy mpi

* restored changes from rl-algs

* remove extra files

* address Chris' comments
This commit is contained in:
pzhokhov
2019-01-15 09:59:27 -08:00
committed by Peter Zhokhov
parent 3a8f35a7e9
commit a538e3c8f7
4 changed files with 74 additions and 5 deletions

View File

@@ -35,7 +35,8 @@ def make_vec_env(env_id, env_type, num_env, seed,
return lambda: make_env(
env_id=env_id,
env_type=env_type,
subrank = rank,
mpi_rank=mpi_rank,
subrank=rank,
seed=seed,
reward_scale=reward_scale,
gamestate=gamestate,
@@ -51,8 +52,7 @@ def make_vec_env(env_id, env_type, num_env, seed,
return DummyVecEnv([make_thunk(start_index)])
def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, logger_dir=None):
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, logger_dir=None):
wrapper_kwargs = wrapper_kwargs or {}
if env_type == 'atari':
env = make_atari(env_id)

View File

@@ -0,0 +1,33 @@
import os
import sys
import subprocess
import cloudpickle
import base64
import pytest
from mpi4py import MPI
def test_with_mpi(nproc=2, timeout=5, skip_if_no_mpi=True):
def outer_thunk(fn):
def thunk(*args, **kwargs):
serialized_fn = base64.b64encode(cloudpickle.dumps(lambda: fn(*args, **kwargs)))
subprocess.check_call([
'mpiexec','-n', str(nproc),
sys.executable,
'-m', 'baselines.common.tests.test_with_mpi',
serialized_fn
], env=os.environ, timeout=timeout)
if skip_if_no_mpi:
return pytest.mark.skipif(MPI is None, reason="MPI not present")(thunk)
else:
return thunk
return outer_thunk
if __name__ == '__main__':
if len(sys.argv) > 1:
fn = cloudpickle.loads(base64.b64decode(sys.argv[1]))
assert callable(fn)
fn()

View File

@@ -1,7 +1,29 @@
import numpy as np
import contextlib
import multiprocessing as mp
import os
import numpy as np
from . import VecEnv, CloudpickleWrapper
@contextlib.contextmanager
def clear_mpi_env_vars():
"""
from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang.
This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing
Processes.
"""
removed_environment = {}
for k, v in list(os.environ.items()):
for prefix in ['OMPI_', 'PMI_']:
if k.startswith(prefix):
removed_environment[k] = v
del os.environ[k]
try:
yield
finally:
os.environ.update(removed_environment)
ctx = mp.get_context('spawn')
def worker(remote, parent_remote, env_fn_wrapper):
@@ -52,7 +74,8 @@ class SubprocVecEnv(VecEnv):
for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
for p in self.ps:
p.daemon = True # if the main process crashes, we should not cause things to hang
p.start()
with clear_mpi_env_vars():
p.start()
for remote in self.work_remotes:
remote.close()

View File

@@ -8,6 +8,7 @@ import pytest
from .dummy_vec_env import DummyVecEnv
from .shmem_vec_env import ShmemVecEnv
from .subproc_vec_env import SubprocVecEnv
from baselines.common.tests.test_with_mpi import test_with_mpi
def assert_envs_equal(env1, env2, num_steps):
@@ -99,3 +100,15 @@ class SimpleEnv(gym.Env):
def render(self, mode=None):
raise NotImplementedError
@test_with_mpi()
def test_mpi_with_subprocvecenv():
shape = (2,3,4)
nenv = 1
venv = SubprocVecEnv([lambda: SimpleEnv(0, shape, 'float32')] * nenv)
ob = venv.reset()
venv.close()
assert ob.shape == (nenv,) + shape