Reduce duplication in VecEnv subclasses. (#38)

* Reduce duplication in VecEnv subclasses.
The VecEnv base class now handles rendering and closing; subclasses should provide get_images and (optionally) close_extras.

* fix tests

* minor docstring change

* raise NotImplementedError
This commit is contained in:
John Schulman
2018-08-22 13:54:34 -07:00
committed by Peter Zhokhov
parent c8f6d8bac7
commit 14c1d69ef4
4 changed files with 47 additions and 50 deletions

View File

@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from baselines import logger
from baselines.common.tile_images import tile_images
class AlreadySteppingError(Exception):
"""
@@ -33,6 +32,8 @@ class VecEnv(ABC):
self.num_envs = num_envs
self.observation_space = observation_space
self.action_space = action_space
self.closed = False
self.viewer = None # For rendering
@abstractmethod
def reset(self):
@@ -72,13 +73,21 @@ class VecEnv(ABC):
"""
pass
@abstractmethod
def close(self):
def close_extras(self):
"""
Clean up the environments' resources.
Clean up the extra resources, beyond what's in this base class.
Only runs when not self.closed.
"""
pass
def close(self):
    """
    Release the viewer (if one was created) and subclass resources.

    Calling this again after a successful close does nothing.
    """
    if not self.closed:
        viewer = self.viewer
        if viewer is not None:
            viewer.close()
        # Hook for cleanup beyond what this base class owns.
        self.close_extras()
        self.closed = True
def step(self, actions):
"""
Step the environments synchronously.
@@ -89,7 +98,20 @@ class VecEnv(ABC):
return self.step_wait()
def render(self, mode='human'):
logger.warn('Render not defined for %s' % self)
imgs = self.get_images()
bigimg = tile_images(imgs)
if mode == 'human':
self.get_viewer().imshow(bigimg)
elif mode == 'rgb_array':
return bigimg
else:
raise NotImplementedError
def get_images(self):
    """
    Return RGB images from each environment.

    Subclasses that support rendering must override this method.

    Raises:
        NotImplementedError: always, in this default implementation.
    """
    # Name the concrete class so the failure points at the missing override.
    raise NotImplementedError(
        '%s does not implement get_images()' % type(self).__name__
    )
@property
def unwrapped(self):
@@ -98,6 +120,12 @@ class VecEnv(ABC):
else:
return self
def get_viewer(self):
    """
    Return the image viewer, creating and caching it on first use.
    """
    if self.viewer is not None:
        return self.viewer
    # Imported inside the method, so the rendering module is only loaded
    # when a viewer is first requested.
    from gym.envs.classic_control import rendering
    self.viewer = rendering.SimpleImageViewer()
    return self.viewer
class VecEnvWrapper(VecEnv):
"""
@@ -126,9 +154,11 @@ class VecEnvWrapper(VecEnv):
def close(self):
    """
    Close the wrapped vectorized environment and return its result.
    """
    wrapped = self.venv
    return wrapped.close()
def render(self):
self.venv.render()
def render(self, mode='human'):
    """
    Render through the wrapped environment, forwarding ``mode`` by keyword.
    """
    inner_render = self.venv.render
    return inner_render(mode=mode)
def get_images(self):
    """
    Return the images produced by the wrapped environment's get_images().
    """
    wrapped = self.venv
    return wrapped.get_images()
class CloudpickleWrapper(object):
"""

View File

@@ -53,9 +53,6 @@ class DummyVecEnv(VecEnv):
def close(self):
    """
    No-op close; always returns None.
    """
    # NOTE(review): this override appears to replace the base-class close(),
    # so viewer teardown there would not run for this class -- confirm intent.
    return None
def render(self, mode='human'):
return [e.render(mode=mode) for e in self.envs]
def _save_obs(self, e, obs):
for k in self.keys:
if k is None:
@@ -65,4 +62,7 @@ class DummyVecEnv(VecEnv):
def _obs_from_buf(self):
    """
    Build the observation to return from a copy of the internal buffers.
    """
    # Copy first so the caller's result is produced from the copy, not
    # from self.buf_obs itself.
    snapshot = copy_obs_dict(self.buf_obs)
    return dict_to_obs(snapshot)
def get_images(self):
    """
    Return one RGB image per environment, in the order of ``self.envs``.

    Each image is obtained by calling that environment's ``render`` with
    ``mode='rgb_array'``.
    """
    # 'rgb_array' is the mode name used by the base-class render();
    # the previous 'rgb' is not a mode name used anywhere else here.
    return [env.render(mode='rgb_array') for env in self.envs]

View File

@@ -7,7 +7,6 @@ import numpy as np
from . import VecEnv, CloudpickleWrapper
import ctypes
from baselines import logger
from baselines.common.tile_images import tile_images
from .util import dict_to_obs, obs_space_info, obs_to_dict
@@ -76,7 +75,7 @@ class ShmemVecEnv(VecEnv):
obs, rews, dones, infos = zip(*outs)
return self._decode_obses(obs), np.array(rews), np.array(dones), infos
def close(self):
def close_extras(self):
if self.waiting_step:
self.step_wait()
for pipe in self.parent_pipes:
@@ -86,24 +85,11 @@ class ShmemVecEnv(VecEnv):
pipe.close()
for proc in self.procs:
proc.join()
if self.viewer is not None:
self.viewer.close()
def render(self, mode='human'):
def get_images(self, mode='human'):
for pipe in self.parent_pipes:
pipe.send(('render', None))
imgs = [pipe.recv() for pipe in self.parent_pipes]
bigimg = tile_images(imgs)
if mode == 'human':
if self.viewer is None:
from gym.envs.classic_control import rendering
self.viewer = rendering.SimpleImageViewer()
self.viewer.imshow(bigimg[:, :, ::-1])
elif mode == 'rgb_array':
return bigimg
else:
raise NotImplementedError
return [pipe.recv() for pipe in self.parent_pipes]
def _decode_obses(self, obs):
result = {}

View File

@@ -1,8 +1,6 @@
import numpy as np
from multiprocessing import Process, Pipe
from . import VecEnv, CloudpickleWrapper
from baselines.common.tile_images import tile_images
def worker(remote, parent_remote, env_fn_wrapper):
parent_remote.close()
@@ -39,7 +37,6 @@ class SubprocVecEnv(VecEnv):
envs: list of gym environments to run in subprocesses
"""
self.waiting = False
self.closed = False
nenvs = len(env_fns)
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
@@ -76,9 +73,7 @@ class SubprocVecEnv(VecEnv):
remote.send(('reset_task', None))
return np.stack([remote.recv() for remote in self.remotes])
def close(self):
if self.closed:
return
def close_extras(self):
if self.waiting:
for remote in self.remotes:
remote.recv()
@@ -86,23 +81,9 @@ class SubprocVecEnv(VecEnv):
remote.send(('close', None))
for p in self.ps:
p.join()
if self.viewer is not None:
self.viewer.close()
self.closed = True
def render(self, mode='human'):
def get_images(self):
for pipe in self.remotes:
pipe.send(('render', None))
imgs = [pipe.recv() for pipe in self.remotes]
bigimg = tile_images(imgs)
if mode == 'human':
if self.viewer is None:
from gym.envs.classic_control import rendering
self.viewer = rendering.SimpleImageViewer()
self.viewer.imshow(bigimg[:, :, ::-1])
elif mode == 'rgb_array':
return bigimg
else:
raise NotImplementedError
return imgs