Compare commits
2 Commits
fix_build
...
peterz_ale
Author | SHA1 | Date | |
---|---|---|---|
|
0f281fd0ca | ||
|
ef4146005a |
@@ -1,28 +1,34 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from baselines import logger
|
from baselines import logger
|
||||||
|
|
||||||
|
|
||||||
class AlreadySteppingError(Exception):
|
class AlreadySteppingError(Exception):
|
||||||
"""
|
"""
|
||||||
Raised when an asynchronous step is running while
|
Raised when an asynchronous step is running while
|
||||||
step_async() is called again.
|
step_async() is called again.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
msg = 'already running an async step'
|
msg = 'already running an async step'
|
||||||
Exception.__init__(self, msg)
|
Exception.__init__(self, msg)
|
||||||
|
|
||||||
|
|
||||||
class NotSteppingError(Exception):
|
class NotSteppingError(Exception):
|
||||||
"""
|
"""
|
||||||
Raised when an asynchronous step is not running but
|
Raised when an asynchronous step is not running but
|
||||||
step_wait() is called.
|
step_wait() is called.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
msg = 'not running an async step'
|
msg = 'not running an async step'
|
||||||
Exception.__init__(self, msg)
|
Exception.__init__(self, msg)
|
||||||
|
|
||||||
|
|
||||||
class VecEnv(ABC):
|
class VecEnv(ABC):
|
||||||
"""
|
"""
|
||||||
An abstract asynchronous, vectorized environment.
|
An abstract asynchronous, vectorized environment.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_envs, observation_space, action_space):
|
def __init__(self, num_envs, observation_space, action_space):
|
||||||
self.num_envs = num_envs
|
self.num_envs = num_envs
|
||||||
self.observation_space = observation_space
|
self.observation_space = observation_space
|
||||||
@@ -32,7 +38,7 @@ class VecEnv(ABC):
|
|||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
Reset all the environments and return an array of
|
Reset all the environments and return an array of
|
||||||
observations, or a tuple of observation arrays.
|
observations, or a dict of observation arrays.
|
||||||
|
|
||||||
If step_async is still doing work, that work will
|
If step_async is still doing work, that work will
|
||||||
be cancelled and step_wait() should not be called
|
be cancelled and step_wait() should not be called
|
||||||
@@ -58,7 +64,7 @@ class VecEnv(ABC):
|
|||||||
Wait for the step taken with step_async().
|
Wait for the step taken with step_async().
|
||||||
|
|
||||||
Returns (obs, rews, dones, infos):
|
Returns (obs, rews, dones, infos):
|
||||||
- obs: an array of observations, or a tuple of
|
- obs: an array of observations, or a dict of
|
||||||
arrays of observations.
|
arrays of observations.
|
||||||
- rews: an array of rewards
|
- rews: an array of rewards
|
||||||
- dones: an array of "episode done" booleans
|
- dones: an array of "episode done" booleans
|
||||||
@@ -74,11 +80,16 @@ class VecEnv(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def step(self, actions):
|
def step(self, actions):
|
||||||
|
"""
|
||||||
|
Step the environments synchronously.
|
||||||
|
|
||||||
|
This is available for backwards compatibility.
|
||||||
|
"""
|
||||||
self.step_async(actions)
|
self.step_async(actions)
|
||||||
return self.step_wait()
|
return self.step_wait()
|
||||||
|
|
||||||
def render(self, mode='human'):
|
def render(self, mode='human'):
|
||||||
logger.warn('Render not defined for %s'%self)
|
logger.warn('Render not defined for %s' % self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unwrapped(self):
|
def unwrapped(self):
|
||||||
@@ -87,13 +98,19 @@ class VecEnv(ABC):
|
|||||||
else:
|
else:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class VecEnvWrapper(VecEnv):
|
class VecEnvWrapper(VecEnv):
|
||||||
|
"""
|
||||||
|
An environment wrapper that applies to an entire batch
|
||||||
|
of environments at once.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, venv, observation_space=None, action_space=None):
|
def __init__(self, venv, observation_space=None, action_space=None):
|
||||||
self.venv = venv
|
self.venv = venv
|
||||||
VecEnv.__init__(self,
|
VecEnv.__init__(self,
|
||||||
num_envs=venv.num_envs,
|
num_envs=venv.num_envs,
|
||||||
observation_space=observation_space or venv.observation_space,
|
observation_space=observation_space or venv.observation_space,
|
||||||
action_space=action_space or venv.action_space)
|
action_space=action_space or venv.action_space)
|
||||||
|
|
||||||
def step_async(self, actions):
|
def step_async(self, actions):
|
||||||
self.venv.step_async(actions)
|
self.venv.step_async(actions)
|
||||||
@@ -112,15 +129,19 @@ class VecEnvWrapper(VecEnv):
|
|||||||
def render(self):
|
def render(self):
|
||||||
self.venv.render()
|
self.venv.render()
|
||||||
|
|
||||||
|
|
||||||
class CloudpickleWrapper(object):
|
class CloudpickleWrapper(object):
|
||||||
"""
|
"""
|
||||||
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
|
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, x):
|
def __init__(self, x):
|
||||||
self.x = x
|
self.x = x
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
return cloudpickle.dumps(self.x)
|
return cloudpickle.dumps(self.x)
|
||||||
|
|
||||||
def __setstate__(self, ob):
|
def __setstate__(self, ob):
|
||||||
import pickle
|
import pickle
|
||||||
self.x = pickle.loads(ob)
|
self.x = pickle.loads(ob)
|
||||||
|
@@ -1,59 +1,40 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from gym import spaces
|
|
||||||
from collections import OrderedDict
|
|
||||||
from . import VecEnv
|
from . import VecEnv
|
||||||
|
from .util import copy_obs_dict, dict_to_obs, obs_space_info
|
||||||
|
|
||||||
|
|
||||||
class DummyVecEnv(VecEnv):
|
class DummyVecEnv(VecEnv):
|
||||||
|
"""
|
||||||
|
A VecEnv that wraps raw gym.Envs.
|
||||||
|
|
||||||
|
This can be used when an algorithm requires a VecEnv
|
||||||
|
but you want to use a vanilla gym.Env instance.
|
||||||
|
It is also useful for avoiding IPC overhead when you
|
||||||
|
don't need to run environments in parallel.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, env_fns):
|
def __init__(self, env_fns):
|
||||||
self.envs = [fn() for fn in env_fns]
|
self.envs = [fn() for fn in env_fns]
|
||||||
env = self.envs[0]
|
env = self.envs[0]
|
||||||
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
|
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
|
||||||
shapes, dtypes = {}, {}
|
|
||||||
self.keys = []
|
|
||||||
obs_space = env.observation_space
|
obs_space = env.observation_space
|
||||||
|
self.keys, shapes, dtypes = obs_space_info(obs_space)
|
||||||
if isinstance(obs_space, spaces.Dict):
|
self.buf_obs = {k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys}
|
||||||
assert isinstance(obs_space.spaces, OrderedDict)
|
|
||||||
subspaces = obs_space.spaces
|
|
||||||
else:
|
|
||||||
subspaces = {None: obs_space}
|
|
||||||
|
|
||||||
for key, box in subspaces.items():
|
|
||||||
shapes[key] = box.shape
|
|
||||||
dtypes[key] = box.dtype
|
|
||||||
self.keys.append(key)
|
|
||||||
|
|
||||||
self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys }
|
|
||||||
self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
|
self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
|
||||||
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
|
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
|
||||||
self.buf_infos = [{} for _ in range(self.num_envs)]
|
self.buf_infos = [{} for _ in range(self.num_envs)]
|
||||||
self.actions = None
|
self.actions = None
|
||||||
|
|
||||||
def step_async(self, actions):
|
def step_async(self, actions):
|
||||||
listify = True
|
self.actions = actions
|
||||||
try:
|
|
||||||
if len(actions) == self.num_envs:
|
|
||||||
listify = False
|
|
||||||
except TypeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not listify:
|
|
||||||
self.actions = actions
|
|
||||||
else:
|
|
||||||
assert self.num_envs == 1, "actions {} is either not a list or has a wrong size - cannot match to {} environments".format(actions, self.num_envs)
|
|
||||||
self.actions = [actions]
|
|
||||||
|
|
||||||
def step_wait(self):
|
def step_wait(self):
|
||||||
for e in range(self.num_envs):
|
for e in range(self.num_envs):
|
||||||
action = self.actions[e]
|
obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(self.actions[e])
|
||||||
if isinstance(self.envs[e].action_space, spaces.Discrete):
|
|
||||||
action = int(action)
|
|
||||||
|
|
||||||
obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(action)
|
|
||||||
if self.buf_dones[e]:
|
if self.buf_dones[e]:
|
||||||
obs = self.envs[e].reset()
|
obs = self.envs[e].reset()
|
||||||
self._save_obs(e, obs)
|
self._save_obs(e, obs)
|
||||||
return (np.copy(self._obs_from_buf()), np.copy(self.buf_rews), np.copy(self.buf_dones),
|
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
|
||||||
self.buf_infos.copy())
|
self.buf_infos.copy())
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
@@ -63,7 +44,8 @@ class DummyVecEnv(VecEnv):
|
|||||||
return self._obs_from_buf()
|
return self._obs_from_buf()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
return
|
for e in self.envs:
|
||||||
|
e.close()
|
||||||
|
|
||||||
def render(self, mode='human'):
|
def render(self, mode='human'):
|
||||||
return [e.render(mode=mode) for e in self.envs]
|
return [e.render(mode=mode) for e in self.envs]
|
||||||
@@ -76,7 +58,4 @@ class DummyVecEnv(VecEnv):
|
|||||||
self.buf_obs[k][e] = obs[k]
|
self.buf_obs[k][e] = obs[k]
|
||||||
|
|
||||||
def _obs_from_buf(self):
|
def _obs_from_buf(self):
|
||||||
if self.keys==[None]:
|
return dict_to_obs(copy_obs_dict(self.buf_obs))
|
||||||
return self.buf_obs[None]
|
|
||||||
else:
|
|
||||||
return self.buf_obs
|
|
||||||
|
146
baselines/common/vec_env/shmem_vec_env.py
Normal file
146
baselines/common/vec_env/shmem_vec_env.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""
|
||||||
|
An interface for asynchronous vectorized environments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from multiprocessing import Pipe, Array, Process
|
||||||
|
import numpy as np
|
||||||
|
from . import VecEnv, CloudpickleWrapper
|
||||||
|
import ctypes
|
||||||
|
from baselines import logger
|
||||||
|
from baselines.common.tile_images import tile_images
|
||||||
|
|
||||||
|
from .util import dict_to_obs, obs_space_info, obs_to_dict
|
||||||
|
|
||||||
|
# Maps a NumPy scalar type (the value of ``dtype.type``) to the ctypes type
# used when allocating the shared observation Array for that dtype.
#
# The boolean entry is keyed by ``np.bool_`` because ``np.dtype(bool).type``
# is ``np.bool_``. The previous key, ``np.bool``, was historically an alias of
# the builtin ``bool`` (and is absent from some NumPy releases), so a lookup
# by ``dtype.type`` could never match it.
_NP_TO_CT = {np.float32: ctypes.c_float,
             np.int32: ctypes.c_int32,
             np.int8: ctypes.c_int8,
             np.uint8: ctypes.c_char,
             np.bool_: ctypes.c_bool}
|
||||||
|
|
||||||
|
|
||||||
|
class ShmemVecEnv(VecEnv):
    """
    A VecEnv that runs every environment in its own subprocess.

    Commands and the (reward, done, info) results travel over one Pipe per
    worker. Observations are written by the workers into per-environment
    multiprocessing Arrays and are read back from those buffers here.
    """

    def __init__(self, env_fns, spaces=None):
        """
        Start one daemon worker process per entry of env_fns.

        Args:
            env_fns: a sequence of callables, each creating one environment.
            spaces: optional (observation_space, action_space) pair. When it
                is omitted (or falsy), a throwaway environment is created
                from env_fns[0] only to read its spaces and is then closed.
        """
        if spaces:
            observation_space, action_space = spaces
        else:
            logger.log('Creating dummy env object to get spaces')
            with logger.scoped_configure(format_strs=[]):
                dummy = env_fns[0]()
                observation_space, action_space = dummy.observation_space, dummy.action_space
                dummy.close()
                del dummy
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)
        self.obs_keys, self.obs_shapes, self.obs_dtypes = obs_space_info(observation_space)
        # One flat Array per observation key and per environment; the element
        # count is the product of that key's shape.
        self.obs_bufs = [
            {k: Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k]))) for k in self.obs_keys}
            for _ in env_fns]
        self.parent_pipes = []
        self.procs = []
        for env_fn, obs_buf in zip(env_fns, self.obs_bufs):
            wrapped_fn = CloudpickleWrapper(env_fn)
            parent_pipe, child_pipe = Pipe()
            proc = Process(target=_subproc_worker,
                           args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes, self.obs_dtypes, self.obs_keys))
            proc.daemon = True
            self.procs.append(proc)
            self.parent_pipes.append(parent_pipe)
            proc.start()
            # The parent keeps only its own end of the pipe.
            child_pipe.close()
        # True between step_async() and the matching step_wait().
        self.waiting_step = False

    def reset(self):
        """
        Reset every environment and return the decoded observations.

        If a step is still pending, its results are received (and dropped)
        first so they are not mistaken for the replies to 'reset'.
        """
        if self.waiting_step:
            logger.warn('Called reset() while waiting for the step to complete')
            self.step_wait()
        for pipe in self.parent_pipes:
            pipe.send(('reset', None))
        return self._decode_obses([pipe.recv() for pipe in self.parent_pipes])

    def step_async(self, actions):
        """
        Send one action to each worker without waiting for the results.

        Raises AssertionError if len(actions) differs from the number of
        workers.
        """
        assert len(actions) == len(self.parent_pipes)
        for pipe, act in zip(self.parent_pipes, actions):
            pipe.send(('step', act))
        # Record the pending step so reset()/close() know they have to drain it.
        self.waiting_step = True

    def step_wait(self):
        """
        Receive the results of the step started with step_async().

        Returns (obs, rews, dones, infos): obs is decoded from the shared
        buffers, rews and dones are arrays, infos is a tuple with one entry
        per environment.
        """
        outs = [pipe.recv() for pipe in self.parent_pipes]
        # Every reply has been received, so no step is pending any more.
        self.waiting_step = False
        obs, rews, dones, infos = zip(*outs)
        return self._decode_obses(obs), np.array(rews), np.array(dones), infos

    def close(self):
        """
        Ask every worker to exit, close the pipes and join the processes.

        A pending step is drained first so that the replies read below are
        the acknowledgements of 'close'.
        """
        if self.waiting_step:
            self.step_wait()
        for pipe in self.parent_pipes:
            pipe.send(('close', None))
        for pipe in self.parent_pipes:
            pipe.recv()
            pipe.close()
        for proc in self.procs:
            proc.join()

    def render(self, mode='human'):
        """
        Collect one image from every worker and tile them into one image.

        With mode 'human' the tiled image is shown in a cv2 window and None
        is returned; with 'rgb_array' the tiled image is returned. Any other
        mode raises NotImplementedError (after the images were collected).
        """
        for pipe in self.parent_pipes:
            pipe.send(('render', None))
        imgs = [pipe.recv() for pipe in self.parent_pipes]
        bigimg = tile_images(imgs)
        if mode == 'human':
            import cv2
            # The last axis is reversed before display; presumably RGB -> BGR
            # for OpenCV -- TODO confirm.
            cv2.imshow('vecenv', bigimg[:, :, ::-1])
            cv2.waitKey(1)
        elif mode == 'rgb_array':
            return bigimg
        else:
            raise NotImplementedError

    def _decode_obses(self, obs):
        """
        Build the batched observation from the shared buffers.

        The obs argument is not used: the values are read from obs_bufs, one
        buffer per environment and key, and stacked with np.array (which
        copies them out of the shared memory).
        """
        result = {}
        for k in self.obs_keys:
            bufs = [b[k] for b in self.obs_bufs]
            o = [np.frombuffer(b.get_obj(), dtype=self.obs_dtypes[k]).reshape(self.obs_shapes[k]) for b in bufs]
            result[k] = np.array(o)
        return dict_to_obs(result)
|
||||||
|
|
||||||
|
|
||||||
|
def _subproc_worker(pipe, parent_pipe, env_fn_wrapper, obs_bufs, obs_shapes, obs_dtypes, keys):
    """
    Worker loop: drive one environment from commands received on `pipe`.

    Observations are never sent over the pipe; they are copied into
    `obs_bufs` and None is sent in their place. Understood commands are
    'reset', 'step', 'render' and 'close'; anything else raises RuntimeError.
    The environment is closed when the loop ends for any reason.
    """
    def _store_obs(observation):
        # Copy each component of the observation into its shared buffer.
        as_dict = obs_to_dict(observation)
        for key in keys:
            raw = obs_bufs[key].get_obj()
            view = np.frombuffer(raw, dtype=obs_dtypes[key]).reshape(obs_shapes[key])  # pylint: disable=W0212
            np.copyto(view, as_dict[key])

    env = env_fn_wrapper.x()
    # This process only uses its own end of the pipe.
    parent_pipe.close()
    try:
        running = True
        while running:
            cmd, data = pipe.recv()
            if cmd == 'step':
                obs, reward, done, info = env.step(data)
                if done:
                    # Start the next episode right away; its first
                    # observation replaces the terminal one.
                    obs = env.reset()
                _store_obs(obs)
                pipe.send((None, reward, done, info))
            elif cmd == 'reset':
                _store_obs(env.reset())
                pipe.send(None)
            elif cmd == 'render':
                pipe.send(env.render(mode='rgb_array'))
            elif cmd == 'close':
                pipe.send(None)
                running = False
            else:
                raise RuntimeError('Got unrecognized cmd %s' % cmd)
    except KeyboardInterrupt:
        print('ShmemVecEnv worker: got KeyboardInterrupt')
    finally:
        env.close()
|
@@ -1,6 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from multiprocessing import Process, Pipe
|
from multiprocessing import Process, Pipe
|
||||||
from baselines.common.vec_env import VecEnv, CloudpickleWrapper
|
from . import VecEnv, CloudpickleWrapper
|
||||||
from baselines.common.tile_images import tile_images
|
from baselines.common.tile_images import tile_images
|
||||||
|
|
||||||
|
|
||||||
@@ -32,6 +32,7 @@ def worker(remote, parent_remote, env_fn_wrapper):
|
|||||||
finally:
|
finally:
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
class SubprocVecEnv(VecEnv):
|
class SubprocVecEnv(VecEnv):
|
||||||
def __init__(self, env_fns, spaces=None):
|
def __init__(self, env_fns, spaces=None):
|
||||||
"""
|
"""
|
||||||
@@ -42,9 +43,9 @@ class SubprocVecEnv(VecEnv):
|
|||||||
nenvs = len(env_fns)
|
nenvs = len(env_fns)
|
||||||
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
|
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
|
||||||
self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
|
self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
|
||||||
for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
|
for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
|
||||||
for p in self.ps:
|
for p in self.ps:
|
||||||
p.daemon = True # if the main process crashes, we should not cause things to hang
|
p.daemon = True # if the main process crashes, we should not cause things to hang
|
||||||
p.start()
|
p.start()
|
||||||
for remote in self.work_remotes:
|
for remote in self.work_remotes:
|
||||||
remote.close()
|
remote.close()
|
||||||
@@ -78,7 +79,7 @@ class SubprocVecEnv(VecEnv):
|
|||||||
if self.closed:
|
if self.closed:
|
||||||
return
|
return
|
||||||
if self.waiting:
|
if self.waiting:
|
||||||
for remote in self.remotes:
|
for remote in self.remotes:
|
||||||
remote.recv()
|
remote.recv()
|
||||||
for remote in self.remotes:
|
for remote in self.remotes:
|
||||||
remote.send(('close', None))
|
remote.send(('close', None))
|
||||||
@@ -93,9 +94,9 @@ class SubprocVecEnv(VecEnv):
|
|||||||
bigimg = tile_images(imgs)
|
bigimg = tile_images(imgs)
|
||||||
if mode == 'human':
|
if mode == 'human':
|
||||||
import cv2
|
import cv2
|
||||||
cv2.imshow('vecenv', bigimg[:,:,::-1])
|
cv2.imshow('vecenv', bigimg[:, :, ::-1])
|
||||||
cv2.waitKey(1)
|
cv2.waitKey(1)
|
||||||
elif mode == 'rgb_array':
|
elif mode == 'rgb_array':
|
||||||
return bigimg
|
return bigimg
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
85
baselines/common/vec_env/test_vec_env.py
Normal file
85
baselines/common/vec_env/test_vec_env.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""
|
||||||
|
Tests for asynchronous vectorized environments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from .dummy_vec_env import DummyVecEnv
|
||||||
|
from .shmem_vec_env import ShmemVecEnv
|
||||||
|
from .subproc_vec_env import SubprocVecEnv
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('klass', (ShmemVecEnv, SubprocVecEnv))
@pytest.mark.parametrize('dtype', ('uint8', 'float32'))
def test_vec_env(klass, dtype):  # pylint: disable=R0914
    """
    Check that `klass` behaves like DummyVecEnv.

    Both vectorized environments are built from the same constructors and
    fed the same random actions; observations, rewards, dones and infos
    must agree after reset and after every step.
    """
    num_envs = 3
    num_steps = 100
    shape = (3, 8)

    def make_fn(seed):
        """
        Return a zero-argument constructor for a _SimpleEnv with this seed.
        """
        def build():
            return _SimpleEnv(seed, shape, dtype)
        return build

    fns = [make_fn(seed) for seed in range(num_envs)]
    reference = DummyVecEnv(fns)
    candidate = klass(fns)
    try:
        ref_obs = reference.reset()
        cand_obs = candidate.reset()
        assert np.array(ref_obs).shape == np.array(cand_obs).shape
        assert np.allclose(ref_obs, cand_obs)

        np.random.seed(1337)
        joint_shape = (len(fns),) + shape
        for _ in range(num_steps):
            actions = np.array(np.random.randint(0, 0x100, size=joint_shape),
                               dtype=dtype)
            # Start the step on both before waiting on either.
            reference.step_async(actions)
            candidate.step_async(actions)
            ref_out = reference.step_wait()
            cand_out = candidate.step_wait()
            # obs, rews and dones are compared numerically ...
            for expected, actual in zip(ref_out[:3], cand_out[:3]):
                assert np.array(expected).shape == np.array(actual).shape
                assert np.allclose(expected, actual)
            # ... infos are compared as plain lists of dicts.
            assert list(ref_out[3]) == list(cand_out[3])
    finally:
        reference.close()
        candidate.close()
|
||||||
|
|
||||||
|
|
||||||
|
class _SimpleEnv(gym.Env):
    """
    An environment with a pre-determined observation space
    and RNG seed.

    The observation starts as a seeded random array and every step adds the
    action to it. An episode lasts `seed + 1` steps.
    """

    def __init__(self, seed, shape, dtype):
        """
        Args:
            seed: seeds NumPy's global RNG for the start observation and
                also sets the episode length (seed + 1 steps).
            shape: shape of observations and actions.
            dtype: dtype of observations and actions.
        """
        np.random.seed(seed)
        self._dtype = dtype
        self._start_obs = np.array(np.random.randint(0, 0x100, size=shape),
                                   dtype=dtype)
        self._max_steps = seed + 1
        self._cur_obs = None
        self._cur_step = 0
        self.action_space = gym.spaces.Box(low=0, high=100, shape=shape, dtype=dtype)
        # Observations and actions share one space object.
        self.observation_space = self.action_space

    def step(self, action):
        """
        Add the action to the current observation (in place).

        Returns (obs, reward, done, info); reward is the fraction of the
        episode completed and done becomes True after `_max_steps` steps.
        """
        self._cur_obs += np.array(action, dtype=self._dtype)
        self._cur_step += 1
        done = self._cur_step >= self._max_steps
        reward = self._cur_step / self._max_steps
        return self._cur_obs, reward, done, {'foo': 'bar' + str(reward)}

    def reset(self):
        """
        Restart the episode from the start observation and return it.
        """
        # Work on a copy: step() updates _cur_obs in place, so handing out
        # _start_obs itself would let the first step overwrite the start
        # state and make every later reset begin somewhere else.
        self._cur_obs = self._start_obs.copy()
        self._cur_step = 0
        return self._cur_obs

    def render(self, mode=None):
        """Rendering is not supported by this test environment."""
        raise NotImplementedError
|
59
baselines/common/vec_env/util.py
Normal file
59
baselines/common/vec_env/util.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""
|
||||||
|
Helpers for dealing with vectorized environments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def copy_obs_dict(obs):
    """
    Return a new dict holding a NumPy copy of every value in `obs`.

    Keys are kept as they are; each value is duplicated with np.copy, so
    changing the result does not affect the input.
    """
    duplicate = {}
    for key, value in obs.items():
        duplicate[key] = np.copy(value)
    return duplicate
|
||||||
|
|
||||||
|
|
||||||
|
def dict_to_obs(obs_dict):
    """
    Unwrap an observation dict whose only key is None.

    Such a dict stands for an observation that did not come from a Dict
    space, so the bare value is returned. Any other dict is returned
    unchanged (the same object).
    """
    only_anonymous_key = set(obs_dict.keys()) == {None}
    return obs_dict[None] if only_anonymous_key else obs_dict
|
||||||
|
|
||||||
|
|
||||||
|
def obs_space_info(obs_space):
    """
    Describe a gym.Space as per-key shape and dtype tables.

    A gym.spaces.Dict contributes one entry per sub-space (its `spaces`
    attribute must be an OrderedDict); any other space is treated as a
    single sub-space stored under the key None.

    Returns:
        A tuple (keys, shapes, dtypes):
            keys: a list of dict keys.
            shapes: a dict mapping keys to shapes.
            dtypes: a dict mapping keys to dtypes.
    """
    if not isinstance(obs_space, gym.spaces.Dict):
        subspaces = {None: obs_space}
    else:
        assert isinstance(obs_space.spaces, OrderedDict)
        subspaces = obs_space.spaces
    keys = list(subspaces.keys())
    shapes = {name: space.shape for name, space in subspaces.items()}
    dtypes = {name: space.dtype for name, space in subspaces.items()}
    return keys, shapes, dtypes
|
||||||
|
|
||||||
|
|
||||||
|
def obs_to_dict(obs):
    """
    Return `obs` itself if it is a dict, otherwise wrap it as {None: obs}.
    """
    return obs if isinstance(obs, dict) else {None: obs}
|
@@ -1,18 +1,16 @@
|
|||||||
from baselines.common.vec_env import VecEnvWrapper
|
from . import VecEnvWrapper
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
|
|
||||||
|
|
||||||
class VecFrameStack(VecEnvWrapper):
|
class VecFrameStack(VecEnvWrapper):
|
||||||
"""
|
|
||||||
Vectorized environment base class
|
|
||||||
"""
|
|
||||||
def __init__(self, venv, nstack):
|
def __init__(self, venv, nstack):
|
||||||
self.venv = venv
|
self.venv = venv
|
||||||
self.nstack = nstack
|
self.nstack = nstack
|
||||||
wos = venv.observation_space # wrapped ob space
|
wos = venv.observation_space # wrapped ob space
|
||||||
low = np.repeat(wos.low, self.nstack, axis=-1)
|
low = np.repeat(wos.low, self.nstack, axis=-1)
|
||||||
high = np.repeat(wos.high, self.nstack, axis=-1)
|
high = np.repeat(wos.high, self.nstack, axis=-1)
|
||||||
self.stackedobs = np.zeros((venv.num_envs,)+low.shape, low.dtype)
|
self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
|
||||||
observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
|
observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
|
||||||
VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
|
VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
|
||||||
|
|
||||||
@@ -26,9 +24,6 @@ class VecFrameStack(VecEnvWrapper):
|
|||||||
return self.stackedobs, rews, news, infos
|
return self.stackedobs, rews, news, infos
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
|
||||||
Reset all environments
|
|
||||||
"""
|
|
||||||
obs = self.venv.reset()
|
obs = self.venv.reset()
|
||||||
self.stackedobs[...] = 0
|
self.stackedobs[...] = 0
|
||||||
self.stackedobs[..., -obs.shape[-1]:] = obs
|
self.stackedobs[..., -obs.shape[-1]:] = obs
|
||||||
|
29
baselines/common/vec_env/vec_monitor.py
Normal file
29
baselines/common/vec_env/vec_monitor.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from . import VecEnvWrapper
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class VecMonitor(VecEnvWrapper):
    """
    A VecEnvWrapper that tracks the return and length of every running
    episode and reports them in the info dict of the step that ends it.
    """

    def __init__(self, venv):
        VecEnvWrapper.__init__(self, venv)
        # Per-environment accumulators; allocated by reset().
        self.eprets = None
        self.eplens = None

    def reset(self):
        """
        Reset the wrapped environments and zero all episode statistics.
        """
        first_obs = self.venv.reset()
        self.eprets = np.zeros(self.num_envs, 'f')
        self.eplens = np.zeros(self.num_envs, 'i')
        return first_obs

    def step_wait(self):
        """
        Forward step_wait() and update the episode statistics.

        For each environment whose episode just finished, a copy of its
        info dict gets an 'episode' entry {'r': return, 'l': length} and
        that environment's counters are cleared. The infos handed back are
        copies; the wrapped environment's dicts are left untouched.
        """
        obs, rews, dones, infos = self.venv.step_wait()
        self.eprets += rews
        self.eplens += 1
        annotated = []
        per_env = zip(dones, self.eprets, self.eplens, infos)
        for index, (finished, ep_return, ep_length, raw_info) in enumerate(per_env):
            entry = raw_info.copy()
            if finished:
                entry['episode'] = {'r': ep_return, 'l': ep_length}
                self.eprets[index] = 0
                self.eplens[index] = 0
            annotated.append(entry)
        return obs, rews, dones, annotated
|
@@ -1,17 +1,18 @@
|
|||||||
from baselines.common.vec_env import VecEnvWrapper
|
from . import VecEnvWrapper
|
||||||
from baselines.common.running_mean_std import RunningMeanStd
|
from baselines.common.running_mean_std import RunningMeanStd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class VecNormalize(VecEnvWrapper):
|
class VecNormalize(VecEnvWrapper):
|
||||||
"""
|
"""
|
||||||
Vectorized environment base class
|
A vectorized wrapper that normalizes the observations
|
||||||
|
and returns from an environment.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
|
def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
|
||||||
VecEnvWrapper.__init__(self, venv)
|
VecEnvWrapper.__init__(self, venv)
|
||||||
self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
|
self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
|
||||||
self.ret_rms = RunningMeanStd(shape=()) if ret else None
|
self.ret_rms = RunningMeanStd(shape=()) if ret else None
|
||||||
#self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None
|
|
||||||
#self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None
|
|
||||||
self.clipob = clipob
|
self.clipob = clipob
|
||||||
self.cliprew = cliprew
|
self.cliprew = cliprew
|
||||||
self.ret = np.zeros(self.num_envs)
|
self.ret = np.zeros(self.num_envs)
|
||||||
@@ -19,12 +20,6 @@ class VecNormalize(VecEnvWrapper):
|
|||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
|
|
||||||
def step_wait(self):
|
def step_wait(self):
|
||||||
"""
|
|
||||||
Apply sequence of actions to sequence of environments
|
|
||||||
actions -> (observations, rewards, news)
|
|
||||||
|
|
||||||
where 'news' is a boolean vector indicating whether each element is new.
|
|
||||||
"""
|
|
||||||
obs, rews, news, infos = self.venv.step_wait()
|
obs, rews, news, infos = self.venv.step_wait()
|
||||||
self.ret = self.ret * self.gamma + rews
|
self.ret = self.ret * self.gamma + rews
|
||||||
obs = self._obfilt(obs)
|
obs = self._obfilt(obs)
|
||||||
@@ -42,8 +37,5 @@ class VecNormalize(VecEnvWrapper):
|
|||||||
return obs
|
return obs
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
|
||||||
Reset all environments
|
|
||||||
"""
|
|
||||||
obs = self.venv.reset()
|
obs = self.venv.reset()
|
||||||
return self._obfilt(obs)
|
return self._obfilt(obs)
|
||||||
|
Reference in New Issue
Block a user