Reduce duplication in VecEnv subclasses. (#38)
* Reduce duplication in VecEnv subclasses: the VecEnv base class now handles rendering and closing; subclasses should provide get_images and, optionally, close_extras. * Fix tests. * Minor docstring change. * Raise NotImplementedError.
This commit is contained in:
committed by
Peter Zhokhov
parent
c8f6d8bac7
commit
14c1d69ef4
@@ -1,6 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from baselines import logger
|
||||
|
||||
from baselines.common.tile_images import tile_images
|
||||
|
||||
class AlreadySteppingError(Exception):
|
||||
"""
|
||||
@@ -33,6 +32,8 @@ class VecEnv(ABC):
|
||||
self.num_envs = num_envs
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.closed = False
|
||||
self.viewer = None # For rendering
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
@@ -72,13 +73,21 @@ class VecEnv(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def close(self):
|
||||
def close_extras(self):
|
||||
"""
|
||||
Clean up the environments' resources.
|
||||
Clean up the extra resources, beyond what's in this base class.
|
||||
Only runs when not self.closed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def close(self):
    """
    Release the resources held by this vectorized environment.

    Closes the viewer if one was created, then calls close_extras() for
    subclass-specific cleanup, and finally marks the instance as closed.
    Calling it again once self.closed is set does nothing.
    """
    if not self.closed:
        viewer = self.viewer
        if viewer is not None:
            viewer.close()
        self.close_extras()
        self.closed = True
|
||||
|
||||
def step(self, actions):
|
||||
"""
|
||||
Step the environments synchronously.
|
||||
@@ -89,7 +98,20 @@ class VecEnv(ABC):
|
||||
return self.step_wait()
|
||||
|
||||
def render(self, mode='human'):
|
||||
logger.warn('Render not defined for %s' % self)
|
||||
imgs = self.get_images()
|
||||
bigimg = tile_images(imgs)
|
||||
if mode == 'human':
|
||||
self.get_viewer().imshow(bigimg)
|
||||
elif mode == 'rgb_array':
|
||||
return bigimg
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_images(self):
    """
    Return RGB images from each environment.

    The base class provides no implementation and always raises
    NotImplementedError; subclasses that can render supply their own.
    """
    raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def unwrapped(self):
|
||||
@@ -98,6 +120,12 @@ class VecEnv(ABC):
|
||||
else:
|
||||
return self
|
||||
|
||||
def get_viewer(self):
    """
    Return the image viewer, creating it on first use.

    gym's rendering module is imported only when a viewer actually has to
    be built; the created SimpleImageViewer is stored on self.viewer and
    reused by later calls.
    """
    if self.viewer is not None:
        return self.viewer
    from gym.envs.classic_control import rendering
    self.viewer = rendering.SimpleImageViewer()
    return self.viewer
|
||||
|
||||
|
||||
class VecEnvWrapper(VecEnv):
|
||||
"""
|
||||
@@ -126,9 +154,11 @@ class VecEnvWrapper(VecEnv):
|
||||
def close(self):
    """Close the wrapped vectorized environment and return whatever it returns."""
    wrapped = self.venv
    return wrapped.close()
|
||||
|
||||
def render(self):
|
||||
self.venv.render()
|
||||
def render(self, mode='human'):
    """
    Render through the wrapped vectorized environment.

    The requested mode is forwarded as a keyword argument and the wrapped
    environment's return value is passed back unchanged.
    """
    wrapped = self.venv
    return wrapped.render(mode=mode)
|
||||
|
||||
def get_images(self):
    """Return the images produced by the wrapped vectorized environment."""
    wrapped = self.venv
    return wrapped.get_images()
|
||||
|
||||
class CloudpickleWrapper(object):
|
||||
"""
|
||||
|
@@ -53,9 +53,6 @@ class DummyVecEnv(VecEnv):
|
||||
def close(self):
    """
    Do nothing; this override performs no cleanup.

    NOTE(review): because this replaces VecEnv.close, the base-class steps
    (closing the viewer, calling close_extras, setting self.closed) are
    skipped for this class -- confirm that is intended.
    """
    return None
|
||||
|
||||
def render(self, mode='human'):
|
||||
return [e.render(mode=mode) for e in self.envs]
|
||||
|
||||
def _save_obs(self, e, obs):
|
||||
for k in self.keys:
|
||||
if k is None:
|
||||
@@ -65,4 +62,7 @@ class DummyVecEnv(VecEnv):
|
||||
|
||||
def _obs_from_buf(self):
    """
    Build the observation to return from the internal buffer.

    self.buf_obs is first passed through copy_obs_dict and the result is
    then converted with dict_to_obs.
    """
    snapshot = copy_obs_dict(self.buf_obs)
    return dict_to_obs(snapshot)
|
||||
|
||||
def get_images(self):
    """
    Return one RGB frame per wrapped environment.

    Each environment is rendered with mode='rgb_array', the gym render
    mode that hands the frame back as an array instead of displaying it.
    The previous mode string 'rgb' is not a gym render mode, so the frames
    needed by the base-class render() could not be collected.
    """
    return [env.render(mode='rgb_array') for env in self.envs]
|
||||
|
||||
|
@@ -7,7 +7,6 @@ import numpy as np
|
||||
from . import VecEnv, CloudpickleWrapper
|
||||
import ctypes
|
||||
from baselines import logger
|
||||
from baselines.common.tile_images import tile_images
|
||||
|
||||
from .util import dict_to_obs, obs_space_info, obs_to_dict
|
||||
|
||||
@@ -76,7 +75,7 @@ class ShmemVecEnv(VecEnv):
|
||||
obs, rews, dones, infos = zip(*outs)
|
||||
return self._decode_obses(obs), np.array(rews), np.array(dones), infos
|
||||
|
||||
def close(self):
|
||||
def close_extras(self):
|
||||
if self.waiting_step:
|
||||
self.step_wait()
|
||||
for pipe in self.parent_pipes:
|
||||
@@ -86,24 +85,11 @@ class ShmemVecEnv(VecEnv):
|
||||
pipe.close()
|
||||
for proc in self.procs:
|
||||
proc.join()
|
||||
if self.viewer is not None:
|
||||
self.viewer.close()
|
||||
|
||||
def render(self, mode='human'):
|
||||
def get_images(self, mode='human'):
|
||||
for pipe in self.parent_pipes:
|
||||
pipe.send(('render', None))
|
||||
imgs = [pipe.recv() for pipe in self.parent_pipes]
|
||||
bigimg = tile_images(imgs)
|
||||
if mode == 'human':
|
||||
if self.viewer is None:
|
||||
from gym.envs.classic_control import rendering
|
||||
self.viewer = rendering.SimpleImageViewer()
|
||||
|
||||
self.viewer.imshow(bigimg[:, :, ::-1])
|
||||
elif mode == 'rgb_array':
|
||||
return bigimg
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return [pipe.recv() for pipe in self.parent_pipes]
|
||||
|
||||
def _decode_obses(self, obs):
|
||||
result = {}
|
||||
|
@@ -1,8 +1,6 @@
|
||||
import numpy as np
|
||||
from multiprocessing import Process, Pipe
|
||||
from . import VecEnv, CloudpickleWrapper
|
||||
from baselines.common.tile_images import tile_images
|
||||
|
||||
|
||||
def worker(remote, parent_remote, env_fn_wrapper):
|
||||
parent_remote.close()
|
||||
@@ -39,7 +37,6 @@ class SubprocVecEnv(VecEnv):
|
||||
envs: list of gym environments to run in subprocesses
|
||||
"""
|
||||
self.waiting = False
|
||||
self.closed = False
|
||||
nenvs = len(env_fns)
|
||||
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
|
||||
self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
|
||||
@@ -76,9 +73,7 @@ class SubprocVecEnv(VecEnv):
|
||||
remote.send(('reset_task', None))
|
||||
return np.stack([remote.recv() for remote in self.remotes])
|
||||
|
||||
def close(self):
|
||||
if self.closed:
|
||||
return
|
||||
def close_extras(self):
|
||||
if self.waiting:
|
||||
for remote in self.remotes:
|
||||
remote.recv()
|
||||
@@ -86,23 +81,9 @@ class SubprocVecEnv(VecEnv):
|
||||
remote.send(('close', None))
|
||||
for p in self.ps:
|
||||
p.join()
|
||||
if self.viewer is not None:
|
||||
self.viewer.close()
|
||||
self.closed = True
|
||||
|
||||
def render(self, mode='human'):
|
||||
def get_images(self):
|
||||
for pipe in self.remotes:
|
||||
pipe.send(('render', None))
|
||||
imgs = [pipe.recv() for pipe in self.remotes]
|
||||
bigimg = tile_images(imgs)
|
||||
if mode == 'human':
|
||||
if self.viewer is None:
|
||||
from gym.envs.classic_control import rendering
|
||||
self.viewer = rendering.SimpleImageViewer()
|
||||
|
||||
self.viewer.imshow(bigimg[:, :, ::-1])
|
||||
|
||||
elif mode == 'rgb_array':
|
||||
return bigimg
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return imgs
|
||||
|
Reference in New Issue
Block a user