Make SubprocVecEnv works with DummyVecEnv (#908)

* Make SubprocVecEnv works with DummyVecEnv (nested environments for synchronous sampling)

* SubprocVecEnv now supports running environments in series in each process

* Added docstring to the test definition

* Added additional test to check, whether SubprocVecEnv results with the same output when in_series parameter is enabled and not

* Added more test cases for in_series parameter

* Refactored worker function, added docstring for in_series parameter

* Remove check for TF presence in setup.py
This commit is contained in:
Tomasz Wrona
2019-08-29 21:16:25 +02:00
committed by pzhokhov
parent 0182fe1877
commit d80b075904
2 changed files with 80 additions and 15 deletions

View File

@@ -4,32 +4,35 @@ import numpy as np
from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
def worker(remote, parent_remote, env_fn_wrapper): def worker(remote, parent_remote, env_fn_wrappers):
def step_env(env, action):
ob, reward, done, info = env.step(action)
if done:
ob = env.reset()
return ob, reward, done, info
parent_remote.close() parent_remote.close()
env = env_fn_wrapper.x() envs = [env_fn_wrapper() for env_fn_wrapper in env_fn_wrappers.x]
try: try:
while True: while True:
cmd, data = remote.recv() cmd, data = remote.recv()
if cmd == 'step': if cmd == 'step':
ob, reward, done, info = env.step(data) remote.send([step_env(env, action) for env, action in zip(envs, data)])
if done:
ob = env.reset()
remote.send((ob, reward, done, info))
elif cmd == 'reset': elif cmd == 'reset':
ob = env.reset() remote.send([env.reset() for env in envs])
remote.send(ob)
elif cmd == 'render': elif cmd == 'render':
remote.send(env.render(mode='rgb_array')) remote.send([env.render(mode='rgb_array') for env in envs])
elif cmd == 'close': elif cmd == 'close':
remote.close() remote.close()
break break
elif cmd == 'get_spaces_spec': elif cmd == 'get_spaces_spec':
remote.send((env.observation_space, env.action_space, env.spec)) remote.send((envs[0].observation_space, envs[0].action_space, envs[0].spec))
else: else:
raise NotImplementedError raise NotImplementedError
except KeyboardInterrupt: except KeyboardInterrupt:
print('SubprocVecEnv worker: got KeyboardInterrupt') print('SubprocVecEnv worker: got KeyboardInterrupt')
finally: finally:
for env in envs:
env.close() env.close()
@@ -38,17 +41,23 @@ class SubprocVecEnv(VecEnv):
VecEnv that runs multiple environments in parallel in subproceses and communicates with them via pipes. VecEnv that runs multiple environments in parallel in subproceses and communicates with them via pipes.
Recommended to use when num_envs > 1 and step() can be a bottleneck. Recommended to use when num_envs > 1 and step() can be a bottleneck.
""" """
def __init__(self, env_fns, spaces=None, context='spawn'): def __init__(self, env_fns, spaces=None, context='spawn', in_series=1):
""" """
Arguments: Arguments:
env_fns: iterable of callables - functions that create environments to run in subprocesses. Need to be cloud-pickleable env_fns: iterable of callables - functions that create environments to run in subprocesses. Need to be cloud-pickleable
in_series: number of environments to run in series in a single process
(e.g. when len(env_fns) == 12 and in_series == 3, it will run 4 processes, each running 3 envs in series)
""" """
self.waiting = False self.waiting = False
self.closed = False self.closed = False
self.in_series = in_series
nenvs = len(env_fns) nenvs = len(env_fns)
assert nenvs % in_series == 0, "Number of envs must be divisible by number of envs to run in series"
self.nremotes = nenvs // in_series
env_fns = np.array_split(env_fns, self.nremotes)
ctx = mp.get_context(context) ctx = mp.get_context(context)
self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(nenvs)]) self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(self.nremotes)])
self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
for p in self.ps: for p in self.ps:
@@ -61,10 +70,11 @@ class SubprocVecEnv(VecEnv):
self.remotes[0].send(('get_spaces_spec', None)) self.remotes[0].send(('get_spaces_spec', None))
observation_space, action_space, self.spec = self.remotes[0].recv() observation_space, action_space, self.spec = self.remotes[0].recv()
self.viewer = None self.viewer = None
VecEnv.__init__(self, len(env_fns), observation_space, action_space) VecEnv.__init__(self, nenvs, observation_space, action_space)
def step_async(self, actions): def step_async(self, actions):
self._assert_not_closed() self._assert_not_closed()
actions = np.array_split(actions, self.nremotes)
for remote, action in zip(self.remotes, actions): for remote, action in zip(self.remotes, actions):
remote.send(('step', action)) remote.send(('step', action))
self.waiting = True self.waiting = True
@@ -72,6 +82,7 @@ class SubprocVecEnv(VecEnv):
def step_wait(self): def step_wait(self):
self._assert_not_closed() self._assert_not_closed()
results = [remote.recv() for remote in self.remotes] results = [remote.recv() for remote in self.remotes]
results = _flatten_list(results)
self.waiting = False self.waiting = False
obs, rews, dones, infos = zip(*results) obs, rews, dones, infos = zip(*results)
return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos
@@ -80,7 +91,9 @@ class SubprocVecEnv(VecEnv):
self._assert_not_closed() self._assert_not_closed()
for remote in self.remotes: for remote in self.remotes:
remote.send(('reset', None)) remote.send(('reset', None))
return _flatten_obs([remote.recv() for remote in self.remotes]) obs = [remote.recv() for remote in self.remotes]
obs = _flatten_list(obs)
return _flatten_obs(obs)
def close_extras(self): def close_extras(self):
self.closed = True self.closed = True
@@ -97,6 +110,7 @@ class SubprocVecEnv(VecEnv):
for pipe in self.remotes: for pipe in self.remotes:
pipe.send(('render', None)) pipe.send(('render', None))
imgs = [pipe.recv() for pipe in self.remotes] imgs = [pipe.recv() for pipe in self.remotes]
imgs = _flatten_list(imgs)
return imgs return imgs
def _assert_not_closed(self): def _assert_not_closed(self):
@@ -115,3 +129,10 @@ def _flatten_obs(obs):
return {k: np.stack([o[k] for o in obs]) for k in keys} return {k: np.stack([o[k] for o in obs]) for k in keys}
else: else:
return np.stack(obs) return np.stack(obs)
def _flatten_list(l):
assert isinstance(l, (list, tuple))
assert len(l) > 0
assert all([len(l_) > 0 for l_ in l])
return [l__ for l_ in l for l__ in l_]

View File

@@ -67,6 +67,50 @@ def test_vec_env(klass, dtype): # pylint: disable=R0914
assert_venvs_equal(env1, env2, num_steps=num_steps) assert_venvs_equal(env1, env2, num_steps=num_steps)
@pytest.mark.parametrize('dtype', ('uint8', 'float32'))
@pytest.mark.parametrize('num_envs_in_series', (3, 4, 6))
def test_sync_sampling(dtype, num_envs_in_series):
"""
Test that a SubprocVecEnv running with envs in series
outputs the same as DummyVecEnv.
"""
num_envs = 12
num_steps = 100
shape = (3, 8)
def make_fn(seed):
"""
Get an environment constructor with a seed.
"""
return lambda: SimpleEnv(seed, shape, dtype)
fns = [make_fn(i) for i in range(num_envs)]
env1 = DummyVecEnv(fns)
env2 = SubprocVecEnv(fns, in_series=num_envs_in_series)
assert_venvs_equal(env1, env2, num_steps=num_steps)
@pytest.mark.parametrize('dtype', ('uint8', 'float32'))
@pytest.mark.parametrize('num_envs_in_series', (3, 4, 6))
def test_sync_sampling_sanity(dtype, num_envs_in_series):
"""
Test that a SubprocVecEnv running with envs in series
outputs the same as SubprocVecEnv without running in series.
"""
num_envs = 12
num_steps = 100
shape = (3, 8)
def make_fn(seed):
"""
Get an environment constructor with a seed.
"""
return lambda: SimpleEnv(seed, shape, dtype)
fns = [make_fn(i) for i in range(num_envs)]
env1 = SubprocVecEnv(fns)
env2 = SubprocVecEnv(fns, in_series=num_envs_in_series)
assert_venvs_equal(env1, env2, num_steps=num_steps)
class SimpleEnv(gym.Env): class SimpleEnv(gym.Env):
""" """
An environment with a pre-determined observation space An environment with a pre-determined observation space