add fast failure when calling methods on a closed subprocvecenv (#84)
This commit is contained in:
@@ -37,6 +37,7 @@ class SubprocVecEnv(VecEnv):
|
|||||||
envs: list of gym environments to run in subprocesses
|
envs: list of gym environments to run in subprocesses
|
||||||
"""
|
"""
|
||||||
self.waiting = False
|
self.waiting = False
|
||||||
|
self.closed = False
|
||||||
nenvs = len(env_fns)
|
nenvs = len(env_fns)
|
||||||
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
|
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
|
||||||
self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
|
self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
|
||||||
@@ -53,22 +54,26 @@ class SubprocVecEnv(VecEnv):
|
|||||||
VecEnv.__init__(self, len(env_fns), observation_space, action_space)
|
VecEnv.__init__(self, len(env_fns), observation_space, action_space)
|
||||||
|
|
||||||
def step_async(self, actions):
|
def step_async(self, actions):
|
||||||
|
self._assert_not_closed()
|
||||||
for remote, action in zip(self.remotes, actions):
|
for remote, action in zip(self.remotes, actions):
|
||||||
remote.send(('step', action))
|
remote.send(('step', action))
|
||||||
self.waiting = True
|
self.waiting = True
|
||||||
|
|
||||||
def step_wait(self):
|
def step_wait(self):
|
||||||
|
self._assert_not_closed()
|
||||||
results = [remote.recv() for remote in self.remotes]
|
results = [remote.recv() for remote in self.remotes]
|
||||||
self.waiting = False
|
self.waiting = False
|
||||||
obs, rews, dones, infos = zip(*results)
|
obs, rews, dones, infos = zip(*results)
|
||||||
return np.stack(obs), np.stack(rews), np.stack(dones), infos
|
return np.stack(obs), np.stack(rews), np.stack(dones), infos
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
self._assert_not_closed()
|
||||||
for remote in self.remotes:
|
for remote in self.remotes:
|
||||||
remote.send(('reset', None))
|
remote.send(('reset', None))
|
||||||
return np.stack([remote.recv() for remote in self.remotes])
|
return np.stack([remote.recv() for remote in self.remotes])
|
||||||
|
|
||||||
def close_extras(self):
|
def close_extras(self):
|
||||||
|
self.closed = True
|
||||||
if self.waiting:
|
if self.waiting:
|
||||||
for remote in self.remotes:
|
for remote in self.remotes:
|
||||||
remote.recv()
|
remote.recv()
|
||||||
@@ -78,7 +83,11 @@ class SubprocVecEnv(VecEnv):
|
|||||||
p.join()
|
p.join()
|
||||||
|
|
||||||
def get_images(self):
|
def get_images(self):
|
||||||
|
self._assert_not_closed()
|
||||||
for pipe in self.remotes:
|
for pipe in self.remotes:
|
||||||
pipe.send(('render', None))
|
pipe.send(('render', None))
|
||||||
imgs = [pipe.recv() for pipe in self.remotes]
|
imgs = [pipe.recv() for pipe in self.remotes]
|
||||||
return imgs
|
return imgs
|
||||||
|
|
||||||
|
def _assert_not_closed(self):
|
||||||
|
assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"
|
||||||
|
Reference in New Issue
Block a user