Files
baselines/baselines/common/vec_env/subproc_vec_env.py
pzhokhov a538e3c8f7 disable mpi in subprocesses (#213)
* lazy_mpi load

* cleanups

* more lazy mpi

* don't pretend that class is a module, just use it as a class

* mass-replace mpi4py imports

* flake8

* fix previous lazy_mpi imports

* silly recursion

* try os.environ hack

* better prefix test, work with mpich

* restored MPI imports

* removed commented import in test_with_mpi

* restored codegen from master

* remove lazy mpi

* restored changes from rl-algs

* remove extra files

* address Chris' comments
2019-01-24 14:29:35 -08:00

141 lines
4.8 KiB
Python

import contextlib
import multiprocessing as mp
import os
import numpy as np
from . import VecEnv, CloudpickleWrapper
@contextlib.contextmanager
def clear_mpi_env_vars(prefixes=('OMPI_', 'PMI_')):
    """
    Temporarily remove MPI-related variables from os.environ.

    from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang.
    This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing
    Processes.

    Arguments:

    prefixes: a prefix string, or an iterable of prefix strings; every environment variable whose name starts with one of them is removed
    for the duration of the `with` block. Defaults to the prefixes 'OMPI_' and 'PMI_'.

    On exit (normal or via an exception) the removed variables are written back with their original values.
    """
    # str.startswith needs a tuple; a lone string must not be split into characters.
    if isinstance(prefixes, str):
        prefixes = (prefixes,)
    else:
        prefixes = tuple(prefixes)

    # Snapshot first, delete afterwards, so os.environ is never mutated while being iterated
    # and a name matching several prefixes is still removed exactly once.
    removed_environment = {
        k: v for k, v in os.environ.items() if k.startswith(prefixes)
    }
    for k in removed_environment:
        del os.environ[k]
    try:
        yield
    finally:
        os.environ.update(removed_environment)
# Multiprocessing context that uses the 'spawn' start method; SubprocVecEnv
# creates its pipes and worker processes through this context.
ctx = mp.get_context('spawn')
def worker(remote, parent_remote, env_fn_wrapper):
    """
    Subprocess entry point: build one environment and serve commands received over a pipe.

    Arguments:

    remote: this process's end of the pipe; (cmd, data) tuples are received on it and replies are sent back on it.
    parent_remote: the parent's end of the pipe; closed immediately because this process does not use it.
    env_fn_wrapper: object whose `x` attribute is a zero-argument callable that creates the environment.

    Commands:

    'step': call env.step(data); when the episode is done the environment is reset and the reset observation
            replaces the final one. Replies with (ob, reward, done, info).
    'reset': replies with env.reset().
    'render': replies with env.render(mode='rgb_array').
    'get_spaces_spec': replies with (env.observation_space, env.action_space, env.spec).
    'close': closes `remote` and leaves the loop.

    Raises NotImplementedError (naming the command) for anything else. The environment is closed on every exit path.
    """
    parent_remote.close()
    env = env_fn_wrapper.x()
    try:
        while True:
            cmd, data = remote.recv()
            if cmd == 'step':
                ob, reward, done, info = env.step(data)
                if done:
                    # Auto-reset: the caller gets the first observation of the next episode.
                    ob = env.reset()
                remote.send((ob, reward, done, info))
            elif cmd == 'reset':
                ob = env.reset()
                remote.send(ob)
            elif cmd == 'render':
                remote.send(env.render(mode='rgb_array'))
            elif cmd == 'close':
                remote.close()
                break
            elif cmd == 'get_spaces_spec':
                remote.send((env.observation_space, env.action_space, env.spec))
            else:
                # Name the command so a crash in the subprocess can be diagnosed from its traceback.
                raise NotImplementedError(
                    'SubprocVecEnv worker: unknown command {!r}'.format(cmd))
    except KeyboardInterrupt:
        print('SubprocVecEnv worker: got KeyboardInterrupt')
    finally:
        env.close()
class SubprocVecEnv(VecEnv):
    """
    VecEnv that runs multiple environments in parallel in subprocesses and communicates with them via pipes.
    Recommended to use when num_envs > 1 and step() can be a bottleneck.
    """

    def __init__(self, env_fns, spaces=None):
        """
        Arguments:

        env_fns: iterable of callables - functions that create environments to run in subprocesses. Need to be cloud-pickleable
        spaces: currently ignored by this constructor; the spaces are queried from the first worker instead.

        Raises ValueError if env_fns is empty.
        """
        self.waiting = False
        self.closed = False
        # Materialize once: any iterable is accepted, but both a length and a full pass are needed below.
        env_fns = list(env_fns)
        nenvs = len(env_fns)
        if nenvs == 0:
            # No pipe or process has been created yet, so there is nothing for __del__ to shut down.
            self.closed = True
            raise ValueError('SubprocVecEnv requires at least one environment function')
        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(nenvs)])
        self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            with clear_mpi_env_vars():
                p.start()
        # The worker ends of the pipes are not used by this process.
        for remote in self.work_remotes:
            remote.close()

        # Only the first worker is asked for the spaces and spec.
        self.remotes[0].send(('get_spaces_spec', None))
        observation_space, action_space, self.spec = self.remotes[0].recv()
        self.viewer = None
        VecEnv.__init__(self, nenvs, observation_space, action_space)

    def step_async(self, actions):
        """Send one 'step' command per worker, pairing workers with `actions` in order, without waiting for replies."""
        self._assert_not_closed()
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        """
        Collect one reply from every worker.

        Returns (observations, rewards, dones, infos): observations are batched by _flatten_obs,
        rewards and dones are stacked with np.stack, and infos is a tuple with one entry per worker.
        """
        self._assert_not_closed()
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        """Send 'reset' to every worker and return the batched reset observations."""
        self._assert_not_closed()
        for remote in self.remotes:
            remote.send(('reset', None))
        return _flatten_obs([remote.recv() for remote in self.remotes])

    def close_extras(self):
        """Mark the env closed, tell every worker to close, and join the worker processes."""
        self.closed = True
        if self.waiting:
            # Consume the replies to an outstanding step_async before sending 'close'.
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()

    def get_images(self):
        """Send 'render' to every worker and return the list of replies, one per worker."""
        self._assert_not_closed()
        for pipe in self.remotes:
            pipe.send(('render', None))
        imgs = [pipe.recv() for pipe in self.remotes]
        return imgs

    def _assert_not_closed(self):
        """Fail with an AssertionError when the env has already been closed."""
        assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"

    def __del__(self):
        """Close the env on garbage collection if it has not been closed yet."""
        if not self.closed:
            self.close()
def _flatten_obs(obs):
    """
    Stack per-environment observations into one batched observation.

    Arguments:

    obs: non-empty list or tuple holding one observation per environment. Either every entry is a dict
         (all sharing the keys of the first entry) or every entry can be stacked directly by np.stack.

    Returns a dict mapping each key of the first observation to the np.stack of that key's values across
    all entries, or a single np.stack of the entries when they are not dicts.
    """
    assert isinstance(obs, (list, tuple))
    assert len(obs) > 0

    first = obs[0]
    if isinstance(first, dict):
        # The batch itself is a list/tuple (asserted above), so it must not be required to be a
        # mapping type as well; only its entries are dicts.
        return {k: np.stack([o[k] for o in obs]) for k in first.keys()}
    return np.stack(obs)