From 14c1d69ef439241d409f1b69619feefbfd6cb645 Mon Sep 17 00:00:00 2001 From: John Schulman Date: Wed, 22 Aug 2018 13:54:34 -0700 Subject: [PATCH] Reduce duplication in VecEnv subclasses. (#38) * Reduce duplication in VecEnv subclasses. Now VecEnv base class handles rendering and closing; subclasses should provide get_images and (optionally) close_extras. * fix tests * minor docstring change * raise NotImplementedError --- baselines/common/vec_env/__init__.py | 46 +++++++++++++++++---- baselines/common/vec_env/dummy_vec_env.py | 6 +-- baselines/common/vec_env/shmem_vec_env.py | 20 ++------- baselines/common/vec_env/subproc_vec_env.py | 25 ++--------- 4 files changed, 47 insertions(+), 50 deletions(-) diff --git a/baselines/common/vec_env/__init__.py b/baselines/common/vec_env/__init__.py index c2d987b..37bc78e 100644 --- a/baselines/common/vec_env/__init__.py +++ b/baselines/common/vec_env/__init__.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod -from baselines import logger - +from baselines.common.tile_images import tile_images class AlreadySteppingError(Exception): """ @@ -33,6 +32,8 @@ class VecEnv(ABC): self.num_envs = num_envs self.observation_space = observation_space self.action_space = action_space + self.closed = False + self.viewer = None # For rendering @abstractmethod def reset(self): @@ -72,13 +73,21 @@ class VecEnv(ABC): """ pass - @abstractmethod - def close(self): + def close_extras(self): """ - Clean up the environments' resources. + Clean up the extra resources, beyond what's in this base class. + Only runs when not self.closed. """ pass + def close(self): + if self.closed: + return + if self.viewer is not None: + self.viewer.close() + self.close_extras() + self.closed = True + def step(self, actions): """ Step the environments synchronously. @@ -89,7 +98,20 @@ class VecEnv(ABC): return self.step_wait() def render(self, mode='human'): - logger.warn('Render not defined for %s' % self) + imgs = self.get_images() + bigimg = tile_images(imgs) + if mode == 'human': + self.get_viewer().imshow(bigimg) + elif mode == 'rgb_array': + return bigimg + else: + raise NotImplementedError + + def get_images(self): + """ + Return RGB images from each environment + """ + raise NotImplementedError @property def unwrapped(self): @@ -98,6 +120,12 @@ class VecEnv(ABC): else: return self + def get_viewer(self): + if self.viewer is None: + from gym.envs.classic_control import rendering + self.viewer = rendering.SimpleImageViewer() + return self.viewer + class VecEnvWrapper(VecEnv): """ @@ -126,9 +154,11 @@ class VecEnvWrapper(VecEnv): def close(self): return self.venv.close() - def render(self): - self.venv.render() + def render(self, mode='human'): + return self.venv.render(mode=mode) + def get_images(self): + return self.venv.get_images() class CloudpickleWrapper(object): """ diff --git a/baselines/common/vec_env/dummy_vec_env.py b/baselines/common/vec_env/dummy_vec_env.py index 7a77a5a..af60f76 100644 --- a/baselines/common/vec_env/dummy_vec_env.py +++ b/baselines/common/vec_env/dummy_vec_env.py @@ -53,9 +53,6 @@ class DummyVecEnv(VecEnv): def close(self): return - def render(self, mode='human'): - return [e.render(mode=mode) for e in self.envs] - def _save_obs(self, e, obs): for k in self.keys: if k is None: @@ -65,4 +62,7 @@ class DummyVecEnv(VecEnv): def _obs_from_buf(self): return dict_to_obs(copy_obs_dict(self.buf_obs)) + + def get_images(self): + return [env.render(mode='rgb') for env in self.envs] diff --git a/baselines/common/vec_env/shmem_vec_env.py b/baselines/common/vec_env/shmem_vec_env.py index cc352f8..b2c8e16 100644 --- a/baselines/common/vec_env/shmem_vec_env.py +++ b/baselines/common/vec_env/shmem_vec_env.py @@ -7,7 +7,6 @@ import numpy as np from . import VecEnv, CloudpickleWrapper import ctypes from baselines import logger -from baselines.common.tile_images import tile_images from .util import dict_to_obs, obs_space_info, obs_to_dict @@ -76,7 +75,7 @@ class ShmemVecEnv(VecEnv): obs, rews, dones, infos = zip(*outs) return self._decode_obses(obs), np.array(rews), np.array(dones), infos - def close(self): + def close_extras(self): if self.waiting_step: self.step_wait() for pipe in self.parent_pipes: @@ -86,24 +85,11 @@ class ShmemVecEnv(VecEnv): pipe.close() for proc in self.procs: proc.join() - if self.viewer is not None: - self.viewer.close() - def render(self, mode='human'): + def get_images(self, mode='human'): for pipe in self.parent_pipes: pipe.send(('render', None)) - imgs = [pipe.recv() for pipe in self.parent_pipes] - bigimg = tile_images(imgs) - if mode == 'human': - if self.viewer is None: - from gym.envs.classic_control import rendering - self.viewer = rendering.SimpleImageViewer() - - self.viewer.imshow(bigimg[:, :, ::-1]) - elif mode == 'rgb_array': - return bigimg - else: - raise NotImplementedError + return [pipe.recv() for pipe in self.parent_pipes] def _decode_obses(self, obs): result = {} diff --git a/baselines/common/vec_env/subproc_vec_env.py b/baselines/common/vec_env/subproc_vec_env.py index 949f485..c49330b 100644 --- a/baselines/common/vec_env/subproc_vec_env.py +++ b/baselines/common/vec_env/subproc_vec_env.py @@ -1,8 +1,6 @@ import numpy as np from multiprocessing import Process, Pipe from . import VecEnv, CloudpickleWrapper -from baselines.common.tile_images import tile_images - def worker(remote, parent_remote, env_fn_wrapper): parent_remote.close() @@ -39,7 +37,6 @@ class SubprocVecEnv(VecEnv): envs: list of gym environments to run in subprocesses """ self.waiting = False - self.closed = False nenvs = len(env_fns) self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)]) self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) @@ -76,9 +73,7 @@ class SubprocVecEnv(VecEnv): remote.send(('reset_task', None)) return np.stack([remote.recv() for remote in self.remotes]) - def close(self): - if self.closed: - return + def close_extras(self): if self.waiting: for remote in self.remotes: remote.recv() @@ -86,23 +81,9 @@ class SubprocVecEnv(VecEnv): remote.send(('close', None)) for p in self.ps: p.join() - if self.viewer is not None: - self.viewer.close() - self.closed = True - def render(self, mode='human'): + def get_images(self): for pipe in self.remotes: pipe.send(('render', None)) imgs = [pipe.recv() for pipe in self.remotes] - bigimg = tile_images(imgs) - if mode == 'human': - if self.viewer is None: - from gym.envs.classic_control import rendering - self.viewer = rendering.SimpleImageViewer() - - self.viewer.imshow(bigimg[:, :, ::-1]) - - elif mode == 'rgb_array': - return bigimg - else: - raise NotImplementedError + return imgs