* lazy_mpi load
* cleanups
* more lazy mpi
* don't pretend that class is a module, just use it as a class
* mass-replace mpi4py imports
* flake8
* fix previous lazy_mpi imports
* silly recursion
* try os.environ hack
* better prefix test, work with mpich
* restored MPI imports
* removed commented import in test_with_mpi
* restored codegen from master
* remove lazy mpi
* restored changes from rl-algs
* remove extra files
* port mpi fix to shmem vec env
* increase the mpi test default timeout
committed by Peter Zhokhov
parent a538e3c8f7
commit 8c547e5973
@@ -7,7 +7,7 @@ import pytest
 from mpi4py import MPI
 
 
-def test_with_mpi(nproc=2, timeout=5, skip_if_no_mpi=True):
+def test_with_mpi(nproc=2, timeout=10, skip_if_no_mpi=True):
     def outer_thunk(fn):
         def thunk(*args, **kwargs):
             serialized_fn = base64.b64encode(cloudpickle.dumps(lambda: fn(*args, **kwargs)))
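The helper above ships the test body to the MPI worker processes as a base64 string. A minimal sketch of that round trip, with illustrative names (the payload variable and the printed message are not from the diff):

import base64
import cloudpickle

# Parent side: close over the arguments in a zero-argument lambda, then
# base64-encode so the pickle can travel to the mpiexec-launched children
# (e.g. through an environment variable or a command-line argument).
payload = base64.b64encode(cloudpickle.dumps(lambda: print("hello from worker")))

# Child side: decode and call. cloudpickle, unlike plain pickle, can
# serialize lambdas and locally defined functions.
fn = cloudpickle.loads(base64.b64decode(payload))
fn()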
@@ -1,6 +1,10 @@
+import contextlib
+import os
 from abc import ABC, abstractmethod
+
 from baselines.common.tile_images import tile_images
 
+
 class AlreadySteppingError(Exception):
     """
     Raised when an asynchronous step is running while
@@ -181,6 +185,7 @@ class VecEnvObservationWrapper(VecEnvWrapper):
         obs, rews, dones, infos = self.venv.step_wait()
         return self.process(obs), rews, dones, infos
 
+
 class CloudpickleWrapper(object):
     """
     Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
@@ -196,3 +201,23 @@ class CloudpickleWrapper(object):
     def __setstate__(self, ob):
         import pickle
         self.x = pickle.loads(ob)
+
+
+@contextlib.contextmanager
+def clear_mpi_env_vars():
+    """
+    from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang.
+
+    This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing
+    Processes.
+    """
+    removed_environment = {}
+    for k, v in list(os.environ.items()):
+        for prefix in ['OMPI_', 'PMI_']:
+            if k.startswith(prefix):
+                removed_environment[k] = v
+                del os.environ[k]
+    try:
+        yield
+    finally:
+        os.environ.update(removed_environment)
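The new context manager's behaviour, demonstrated in isolation (the OMPI_ variable is set here only for the demo, and the import path is assumed from this diff's package layout):

import os
from baselines.common.vec_env import clear_mpi_env_vars

os.environ['OMPI_COMM_WORLD_RANK'] = '0'    # pretend mpiexec launched us
with clear_mpi_env_vars():
    # Inside the block the OMPI_*/PMI_* variables are gone, so a process
    # spawned here will not mistake itself for a member of the MPI job.
    assert 'OMPI_COMM_WORLD_RANK' not in os.environ
# On exit the variables are restored for the parent's own MPI usage.
assert os.environ['OMPI_COMM_WORLD_RANK'] == '0'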
@@ -2,9 +2,9 @@
 An interface for asynchronous vectorized environments.
 """
 
-from multiprocessing import Pipe, Array, Process
+import multiprocessing as mp
 import numpy as np
-from . import VecEnv, CloudpickleWrapper
+from . import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
 import ctypes
 from baselines import logger
 
@@ -16,6 +16,8 @@ _NP_TO_CT = {np.float32: ctypes.c_float,
              np.uint8: ctypes.c_char,
              np.bool: ctypes.c_bool}
 
+ctx = mp.get_context('spawn')
+
 
 class ShmemVecEnv(VecEnv):
     """
@@ -39,20 +41,21 @@ class ShmemVecEnv(VecEnv):
         VecEnv.__init__(self, len(env_fns), observation_space, action_space)
         self.obs_keys, self.obs_shapes, self.obs_dtypes = obs_space_info(observation_space)
         self.obs_bufs = [
-            {k: Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k]))) for k in self.obs_keys}
+            {k: ctx.Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k]))) for k in self.obs_keys}
             for _ in env_fns]
         self.parent_pipes = []
         self.procs = []
-        for env_fn, obs_buf in zip(env_fns, self.obs_bufs):
-            wrapped_fn = CloudpickleWrapper(env_fn)
-            parent_pipe, child_pipe = Pipe()
-            proc = Process(target=_subproc_worker,
-                           args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes, self.obs_dtypes, self.obs_keys))
-            proc.daemon = True
-            self.procs.append(proc)
-            self.parent_pipes.append(parent_pipe)
-            proc.start()
-            child_pipe.close()
+        with clear_mpi_env_vars():
+            for env_fn, obs_buf in zip(env_fns, self.obs_bufs):
+                wrapped_fn = CloudpickleWrapper(env_fn)
+                parent_pipe, child_pipe = ctx.Pipe()
+                proc = ctx.Process(target=_subproc_worker,
+                                   args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes, self.obs_dtypes, self.obs_keys))
+                proc.daemon = True
+                self.procs.append(proc)
+                self.parent_pipes.append(parent_pipe)
+                proc.start()
+                child_pipe.close()
         self.waiting_step = False
         self.viewer = None
 
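The pattern the constructor now follows, stripped to its essentials (a sketch, not the class itself): create pipes and processes from a 'spawn' context so children start in a fresh interpreter, and hold clear_mpi_env_vars() open across Process.start() so the spawned child never sees the parent's MPI launcher variables.

import multiprocessing as mp

def _worker(pipe):
    # Runs in a fresh interpreter: no inherited MPI state, no forked locks.
    pipe.send('ready')
    pipe.close()

def start_workers(n):
    ctx = mp.get_context('spawn')          # same choice as the diff above
    pipes, procs = [], []
    # with clear_mpi_env_vars():           # as in ShmemVecEnv.__init__;
    for _ in range(n):                     # omitted so the sketch runs standalone
        parent, child = ctx.Pipe()
        p = ctx.Process(target=_worker, args=(child,), daemon=True)
        p.start()
        child.close()                      # parent keeps only its own end
        pipes.append(parent)
        procs.append(p)
    return pipes, procs

if __name__ == '__main__':                 # required with the 'spawn' method
    pipes, procs = start_workers(2)
    print([pipe.recv() for pipe in pipes])
    for p in procs:
        p.join()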
@@ -1,28 +1,7 @@
-import contextlib
 import multiprocessing as mp
-import os
 
 import numpy as np
-from . import VecEnv, CloudpickleWrapper
-
-
-@contextlib.contextmanager
-def clear_mpi_env_vars():
-    """
-    from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang.
-
-    This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing
-    Processes.
-    """
-    removed_environment = {}
-    for k, v in list(os.environ.items()):
-        for prefix in ['OMPI_', 'PMI_']:
-            if k.startswith(prefix):
-                removed_environment[k] = v
-                del os.environ[k]
-    try:
-        yield
-    finally:
-        os.environ.update(removed_environment)
+from . import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
 
 ctx = mp.get_context('spawn')
 
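Net effect of this last hunk: the helper is no longer duplicated here but lives once in the vec_env package __init__ (apparent from the earlier hunks), and both subprocess-based vec envs share it via the package-relative import shown above. External callers would import it the same way, e.g. (path assumed from the baselines.common.tile_images import earlier in this diff):

from baselines.common.vec_env import clear_mpi_env_vars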