Reduce duplication in VecEnv subclasses. (#38)
* Reduce duplication in VecEnv subclasses. Now VecEnv base class handles rendering and closing; subclasses should provide get_images and (optionally) close_extras. * fix tests * minor docstring change * raise NotImplementedError
This commit is contained in:
committed by
Peter Zhokhov
parent
c8f6d8bac7
commit
14c1d69ef4
@@ -1,6 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from baselines import logger
|
from baselines.common.tile_images import tile_images
|
||||||
|
|
||||||
|
|
||||||
class AlreadySteppingError(Exception):
|
class AlreadySteppingError(Exception):
|
||||||
"""
|
"""
|
||||||
@@ -33,6 +32,8 @@ class VecEnv(ABC):
|
|||||||
self.num_envs = num_envs
|
self.num_envs = num_envs
|
||||||
self.observation_space = observation_space
|
self.observation_space = observation_space
|
||||||
self.action_space = action_space
|
self.action_space = action_space
|
||||||
|
self.closed = False
|
||||||
|
self.viewer = None # For rendering
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def reset(self):
|
def reset(self):
|
||||||
@@ -72,13 +73,21 @@ class VecEnv(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
def close_extras(self):
|
||||||
def close(self):
|
|
||||||
"""
|
"""
|
||||||
Clean up the environments' resources.
|
Clean up the extra resources, beyond what's in this base class.
|
||||||
|
Only runs when not self.closed.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self.closed:
|
||||||
|
return
|
||||||
|
if self.viewer is not None:
|
||||||
|
self.viewer.close()
|
||||||
|
self.close_extras()
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
"""
|
"""
|
||||||
Step the environments synchronously.
|
Step the environments synchronously.
|
||||||
@@ -89,7 +98,20 @@ class VecEnv(ABC):
|
|||||||
return self.step_wait()
|
return self.step_wait()
|
||||||
|
|
||||||
def render(self, mode='human'):
|
def render(self, mode='human'):
|
||||||
logger.warn('Render not defined for %s' % self)
|
imgs = self.get_images()
|
||||||
|
bigimg = tile_images(imgs)
|
||||||
|
if mode == 'human':
|
||||||
|
self.get_viewer().imshow(bigimg)
|
||||||
|
elif mode == 'rgb_array':
|
||||||
|
return bigimg
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_images(self):
|
||||||
|
"""
|
||||||
|
Return RGB images from each environment
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unwrapped(self):
|
def unwrapped(self):
|
||||||
@@ -98,6 +120,12 @@ class VecEnv(ABC):
|
|||||||
else:
|
else:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def get_viewer(self):
|
||||||
|
if self.viewer is None:
|
||||||
|
from gym.envs.classic_control import rendering
|
||||||
|
self.viewer = rendering.SimpleImageViewer()
|
||||||
|
return self.viewer
|
||||||
|
|
||||||
|
|
||||||
class VecEnvWrapper(VecEnv):
|
class VecEnvWrapper(VecEnv):
|
||||||
"""
|
"""
|
||||||
@@ -126,9 +154,11 @@ class VecEnvWrapper(VecEnv):
|
|||||||
def close(self):
|
def close(self):
|
||||||
return self.venv.close()
|
return self.venv.close()
|
||||||
|
|
||||||
def render(self):
|
def render(self, mode='human'):
|
||||||
self.venv.render()
|
return self.venv.render(mode=mode)
|
||||||
|
|
||||||
|
def get_images(self):
|
||||||
|
return self.venv.get_images()
|
||||||
|
|
||||||
class CloudpickleWrapper(object):
|
class CloudpickleWrapper(object):
|
||||||
"""
|
"""
|
||||||
|
@@ -53,9 +53,6 @@ class DummyVecEnv(VecEnv):
|
|||||||
def close(self):
|
def close(self):
|
||||||
return
|
return
|
||||||
|
|
||||||
def render(self, mode='human'):
|
|
||||||
return [e.render(mode=mode) for e in self.envs]
|
|
||||||
|
|
||||||
def _save_obs(self, e, obs):
|
def _save_obs(self, e, obs):
|
||||||
for k in self.keys:
|
for k in self.keys:
|
||||||
if k is None:
|
if k is None:
|
||||||
@@ -65,4 +62,7 @@ class DummyVecEnv(VecEnv):
|
|||||||
|
|
||||||
def _obs_from_buf(self):
|
def _obs_from_buf(self):
|
||||||
return dict_to_obs(copy_obs_dict(self.buf_obs))
|
return dict_to_obs(copy_obs_dict(self.buf_obs))
|
||||||
|
|
||||||
|
def get_images(self):
|
||||||
|
return [env.render(mode='rgb') for env in self.envs]
|
||||||
|
|
||||||
|
@@ -7,7 +7,6 @@ import numpy as np
|
|||||||
from . import VecEnv, CloudpickleWrapper
|
from . import VecEnv, CloudpickleWrapper
|
||||||
import ctypes
|
import ctypes
|
||||||
from baselines import logger
|
from baselines import logger
|
||||||
from baselines.common.tile_images import tile_images
|
|
||||||
|
|
||||||
from .util import dict_to_obs, obs_space_info, obs_to_dict
|
from .util import dict_to_obs, obs_space_info, obs_to_dict
|
||||||
|
|
||||||
@@ -76,7 +75,7 @@ class ShmemVecEnv(VecEnv):
|
|||||||
obs, rews, dones, infos = zip(*outs)
|
obs, rews, dones, infos = zip(*outs)
|
||||||
return self._decode_obses(obs), np.array(rews), np.array(dones), infos
|
return self._decode_obses(obs), np.array(rews), np.array(dones), infos
|
||||||
|
|
||||||
def close(self):
|
def close_extras(self):
|
||||||
if self.waiting_step:
|
if self.waiting_step:
|
||||||
self.step_wait()
|
self.step_wait()
|
||||||
for pipe in self.parent_pipes:
|
for pipe in self.parent_pipes:
|
||||||
@@ -86,24 +85,11 @@ class ShmemVecEnv(VecEnv):
|
|||||||
pipe.close()
|
pipe.close()
|
||||||
for proc in self.procs:
|
for proc in self.procs:
|
||||||
proc.join()
|
proc.join()
|
||||||
if self.viewer is not None:
|
|
||||||
self.viewer.close()
|
|
||||||
|
|
||||||
def render(self, mode='human'):
|
def get_images(self, mode='human'):
|
||||||
for pipe in self.parent_pipes:
|
for pipe in self.parent_pipes:
|
||||||
pipe.send(('render', None))
|
pipe.send(('render', None))
|
||||||
imgs = [pipe.recv() for pipe in self.parent_pipes]
|
return [pipe.recv() for pipe in self.parent_pipes]
|
||||||
bigimg = tile_images(imgs)
|
|
||||||
if mode == 'human':
|
|
||||||
if self.viewer is None:
|
|
||||||
from gym.envs.classic_control import rendering
|
|
||||||
self.viewer = rendering.SimpleImageViewer()
|
|
||||||
|
|
||||||
self.viewer.imshow(bigimg[:, :, ::-1])
|
|
||||||
elif mode == 'rgb_array':
|
|
||||||
return bigimg
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def _decode_obses(self, obs):
|
def _decode_obses(self, obs):
|
||||||
result = {}
|
result = {}
|
||||||
|
@@ -1,8 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from multiprocessing import Process, Pipe
|
from multiprocessing import Process, Pipe
|
||||||
from . import VecEnv, CloudpickleWrapper
|
from . import VecEnv, CloudpickleWrapper
|
||||||
from baselines.common.tile_images import tile_images
|
|
||||||
|
|
||||||
|
|
||||||
def worker(remote, parent_remote, env_fn_wrapper):
|
def worker(remote, parent_remote, env_fn_wrapper):
|
||||||
parent_remote.close()
|
parent_remote.close()
|
||||||
@@ -39,7 +37,6 @@ class SubprocVecEnv(VecEnv):
|
|||||||
envs: list of gym environments to run in subprocesses
|
envs: list of gym environments to run in subprocesses
|
||||||
"""
|
"""
|
||||||
self.waiting = False
|
self.waiting = False
|
||||||
self.closed = False
|
|
||||||
nenvs = len(env_fns)
|
nenvs = len(env_fns)
|
||||||
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
|
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
|
||||||
self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
|
self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
|
||||||
@@ -76,9 +73,7 @@ class SubprocVecEnv(VecEnv):
|
|||||||
remote.send(('reset_task', None))
|
remote.send(('reset_task', None))
|
||||||
return np.stack([remote.recv() for remote in self.remotes])
|
return np.stack([remote.recv() for remote in self.remotes])
|
||||||
|
|
||||||
def close(self):
|
def close_extras(self):
|
||||||
if self.closed:
|
|
||||||
return
|
|
||||||
if self.waiting:
|
if self.waiting:
|
||||||
for remote in self.remotes:
|
for remote in self.remotes:
|
||||||
remote.recv()
|
remote.recv()
|
||||||
@@ -86,23 +81,9 @@ class SubprocVecEnv(VecEnv):
|
|||||||
remote.send(('close', None))
|
remote.send(('close', None))
|
||||||
for p in self.ps:
|
for p in self.ps:
|
||||||
p.join()
|
p.join()
|
||||||
if self.viewer is not None:
|
|
||||||
self.viewer.close()
|
|
||||||
self.closed = True
|
|
||||||
|
|
||||||
def render(self, mode='human'):
|
def get_images(self):
|
||||||
for pipe in self.remotes:
|
for pipe in self.remotes:
|
||||||
pipe.send(('render', None))
|
pipe.send(('render', None))
|
||||||
imgs = [pipe.recv() for pipe in self.remotes]
|
imgs = [pipe.recv() for pipe in self.remotes]
|
||||||
bigimg = tile_images(imgs)
|
return imgs
|
||||||
if mode == 'human':
|
|
||||||
if self.viewer is None:
|
|
||||||
from gym.envs.classic_control import rendering
|
|
||||||
self.viewer = rendering.SimpleImageViewer()
|
|
||||||
|
|
||||||
self.viewer.imshow(bigimg[:, :, ::-1])
|
|
||||||
|
|
||||||
elif mode == 'rgb_array':
|
|
||||||
return bigimg
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
Reference in New Issue
Block a user