diff --git a/README.md b/README.md index d17074a..88ba054 100644 --- a/README.md +++ b/README.md @@ -102,11 +102,11 @@ The algorithms serialization API is not properly unified yet; however, there is The `--load_path` and `--save_path` command-line options load the tensorflow state from a given path before training, and save it after training, respectively. Let's imagine you'd like to train ppo2 on Atari Pong, save the model and then later visualize what it has learnt. ```bash - python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4 --num-timesteps=2e7 --save_path=~/models/pong_20M_ppo2 + python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4 --num_timesteps=2e7 --save_path=~/models/pong_20M_ppo2 ``` This should get to a mean reward per episode of about 20. To load and visualize the model, we'll do the following - load the model, train it for 0 steps, and then visualize: ```bash - python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4 --num-timesteps=0 --load_path=~/models/pong_20M_ppo2 --play + python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4 --num_timesteps=0 --load_path=~/models/pong_20M_ppo2 --play ``` *NOTE:* At the moment Mujoco training uses the VecNormalize wrapper for the environment, which is not saved correctly; so loading models trained on Mujoco will not work well if the environment is recreated. If necessary, you can work around that by replacing RunningMeanStd with TfRunningMeanStd in [baselines/common/vec_env/vec_normalize.py](baselines/common/vec_env/vec_normalize.py#L12). This way, the mean and std of the environment-normalizing wrapper will be saved in tensorflow variables and included in the model file; however, training is slower that way - hence it is not included by default diff --git a/baselines/a2c/README.md b/baselines/a2c/README.md index 2df6eb2..915852b 100644 --- a/baselines/a2c/README.md +++ b/baselines/a2c/README.md @@ -2,4 +2,5 @@ - Original paper: https://arxiv.org/abs/1602.01783 - Baselines blog post: https://blog.openai.com/baselines-acktr-a2c/ -- `python -m baselines.a2c.run_atari` runs the algorithm for 40M frames = 10M timesteps on an Atari game. See help (`-h`) for more options. \ No newline at end of file +- `python -m baselines.run --alg=a2c --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on Atari Pong. See help (`-h`) for more options. +- also refer to the repo-wide [README.md](../../README.md#training-models) diff --git a/baselines/acer/README.md b/baselines/acer/README.md index 7a53d75..d1ef98c 100644 --- a/baselines/acer/README.md +++ b/baselines/acer/README.md @@ -1,4 +1,6 @@ # ACER - Original paper: https://arxiv.org/abs/1611.01224 -- `python -m baselines.acer.run_atari` runs the algorithm for 40M frames = 10M timesteps on an Atari game. See help (`-h`) for more options. \ No newline at end of file +- `python -m baselines.run --alg=acer --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on Atari Pong. See help (`-h`) for more options. +- also refer to the repo-wide [README.md](../../README.md#training-models) + diff --git a/baselines/acktr/README.md b/baselines/acktr/README.md index e8a806d..93692e8 100644 --- a/baselines/acktr/README.md +++ b/baselines/acktr/README.md @@ -2,4 +2,7 @@ - Original paper: https://arxiv.org/abs/1708.05144 - Baselines blog post: https://blog.openai.com/baselines-acktr-a2c/ -- `python -m baselines.acktr.run_atari` runs the algorithm for 40M frames = 10M timesteps on an Atari game. See help (`-h`) for more options.
\ No newline at end of file +- `python -m baselines.run --alg=acktr --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on Atari Pong. See help (`-h`) for more options. +- also refer to the repo-wide [README.md](../../README.md#training-models) + + diff --git a/baselines/acktr/run_atari.py b/baselines/acktr/run_atari.py deleted file mode 100644 index 50e1580..0000000 --- a/baselines/acktr/run_atari.py +++ /dev/null @@ -1,23 +0,0 @@ -#!/usr/bin/env python3 - -from functools import partial - -from baselines import logger -from baselines.acktr.acktr_disc import learn -from baselines.common.cmd_util import make_atari_env, atari_arg_parser -from baselines.common.vec_env.vec_frame_stack import VecFrameStack -from baselines.common.policies import cnn -def train(env_id, num_timesteps, seed, num_cpu): - env = VecFrameStack(make_atari_env(env_id, num_cpu, seed), 4) - policy_fn = cnn(env=env, one_dim_bias=True) - learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), nprocs=num_cpu) - env.close() -def main(): - args = atari_arg_parser().parse_args() - logger.configure() - train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, num_cpu=32) -if __name__ == '__main__': - main() diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py index 681a80c..54121df 100644 --- a/baselines/common/cmd_util.py +++ b/baselines/common/cmd_util.py @@ -40,7 +40,8 @@ def make_mujoco_env(env_id, seed, reward_scale=1.0): myseed = seed + 1000 * rank if seed is not None else None set_global_seeds(myseed) env = gym.make(env_id) - env = Monitor(env, os.path.join(logger.get_dir(), str(rank)), allow_early_resets=True) + logger_path = None if logger.get_dir() is None else os.path.join(logger.get_dir(), str(rank)) + env = Monitor(env, logger_path, allow_early_resets=True) env.seed(seed) if reward_scale != 1.0: diff --git a/baselines/common/tests/test_cartpole.py b/baselines/common/tests/test_cartpole.py index 359006c..fe799a3 100644 --- a/baselines/common/tests/test_cartpole.py +++ b/baselines/common/tests/test_cartpole.py @@ -14,7 +14,7 @@ common_kwargs = dict( learn_kwargs = { 'a2c' : dict(nsteps=32, value_network='copy', lr=0.05), 'acktr': dict(nsteps=32, value_network='copy'), - 'deepq': {}, + 'deepq': dict(total_timesteps=20000), 'ppo2': dict(value_network='copy'), 'trpo_mpi': {} } @@ -38,3 +38,6 @@ def test_cartpole(alg): return env reward_per_episode_test(env_fn, learn_fn, 100) + +if __name__ == '__main__': + test_cartpole('deepq') diff --git a/baselines/common/vec_env/__init__.py b/baselines/common/vec_env/__init__.py index eb07310..c2d987b 100644 --- a/baselines/common/vec_env/__init__.py +++ b/baselines/common/vec_env/__init__.py @@ -1,28 +1,34 @@ from abc import ABC, abstractmethod from baselines import logger + class AlreadySteppingError(Exception): """ Raised when an asynchronous step is running while step_async() is called again. """ + def __init__(self): msg = 'already running an async step' Exception.__init__(self, msg) + class NotSteppingError(Exception): """ Raised when an asynchronous step is not running but step_wait() is called. """ + def __init__(self): msg = 'not running an async step' Exception.__init__(self, msg) + class VecEnv(ABC): """ An abstract asynchronous, vectorized environment.
""" + def __init__(self, num_envs, observation_space, action_space): self.num_envs = num_envs self.observation_space = observation_space @@ -32,7 +38,7 @@ class VecEnv(ABC): def reset(self): """ Reset all the environments and return an array of - observations, or a tuple of observation arrays. + observations, or a dict of observation arrays. If step_async is still doing work, that work will be cancelled and step_wait() should not be called @@ -58,7 +64,7 @@ class VecEnv(ABC): Wait for the step taken with step_async(). Returns (obs, rews, dones, infos): - - obs: an array of observations, or a tuple of + - obs: an array of observations, or a dict of arrays of observations. - rews: an array of rewards - dones: an array of "episode done" booleans @@ -74,11 +80,16 @@ class VecEnv(ABC): pass def step(self, actions): + """ + Step the environments synchronously. + + This is available for backwards compatibility. + """ self.step_async(actions) return self.step_wait() def render(self, mode='human'): - logger.warn('Render not defined for %s'%self) + logger.warn('Render not defined for %s' % self) @property def unwrapped(self): @@ -87,13 +98,19 @@ class VecEnv(ABC): else: return self + class VecEnvWrapper(VecEnv): + """ + An environment wrapper that applies to an entire batch + of environments at once. + """ + def __init__(self, venv, observation_space=None, action_space=None): self.venv = venv - VecEnv.__init__(self, - num_envs=venv.num_envs, - observation_space=observation_space or venv.observation_space, - action_space=action_space or venv.action_space) + VecEnv.__init__(self, + num_envs=venv.num_envs, + observation_space=observation_space or venv.observation_space, + action_space=action_space or venv.action_space) def step_async(self, actions): self.venv.step_async(actions) @@ -112,15 +129,19 @@ class VecEnvWrapper(VecEnv): def render(self): self.venv.render() + class CloudpickleWrapper(object): """ Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) """ + def __init__(self, x): self.x = x + def __getstate__(self): import cloudpickle return cloudpickle.dumps(self.x) + def __setstate__(self, ob): import pickle self.x = pickle.loads(ob) diff --git a/baselines/common/vec_env/dummy_vec_env.py b/baselines/common/vec_env/dummy_vec_env.py index 477bf30..af7f3d6 100644 --- a/baselines/common/vec_env/dummy_vec_env.py +++ b/baselines/common/vec_env/dummy_vec_env.py @@ -2,27 +2,16 @@ import numpy as np from gym import spaces from collections import OrderedDict from . 
import VecEnv +from .util import copy_obs_dict, dict_to_obs, obs_space_info class DummyVecEnv(VecEnv): def __init__(self, env_fns): self.envs = [fn() for fn in env_fns] env = self.envs[0] VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) - shapes, dtypes = {}, {} - self.keys = [] obs_space = env.observation_space - - if isinstance(obs_space, spaces.Dict): - assert isinstance(obs_space.spaces, OrderedDict) - subspaces = obs_space.spaces - else: - subspaces = {None: obs_space} - - for key, box in subspaces.items(): - shapes[key] = box.shape - dtypes[key] = box.dtype - self.keys.append(key) - + + self.keys, shapes, dtypes = obs_space_info(obs_space) self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys } self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool) self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) @@ -53,7 +42,7 @@ class DummyVecEnv(VecEnv): if self.buf_dones[e]: obs = self.envs[e].reset() self._save_obs(e, obs) - return (np.copy(self._obs_from_buf()), np.copy(self.buf_rews), np.copy(self.buf_dones), + return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), self.buf_infos.copy()) def reset(self): @@ -76,7 +65,5 @@ class DummyVecEnv(VecEnv): self.buf_obs[k][e] = obs[k] def _obs_from_buf(self): - if self.keys==[None]: - return self.buf_obs[None] - else: - return self.buf_obs + return dict_to_obs(copy_obs_dict(self.buf_obs)) + diff --git a/baselines/common/vec_env/shmem_vec_env.py b/baselines/common/vec_env/shmem_vec_env.py new file mode 100644 index 0000000..76f2e81 --- /dev/null +++ b/baselines/common/vec_env/shmem_vec_env.py @@ -0,0 +1,152 @@ +""" +An interface for asynchronous vectorized environments. +""" + +from multiprocessing import Pipe, Array, Process +import numpy as np +from . import VecEnv, CloudpickleWrapper +import ctypes +from gym.envs.classic_control import rendering +from baselines import logger +from baselines.common.tile_images import tile_images + +from .util import dict_to_obs, obs_space_info, obs_to_dict + +_NP_TO_CT = {np.float32: ctypes.c_float, + np.int32: ctypes.c_int32, + np.int8: ctypes.c_int8, + np.uint8: ctypes.c_char, + np.bool: ctypes.c_bool} + + +class ShmemVecEnv(VecEnv): + """ + An AsyncEnv that uses multiprocessing to run multiple + environments in parallel. + """ + + def __init__(self, env_fns, spaces=None): + """ + If you don't specify observation_space, we'll have to create a dummy + environment to get it. 
+ """ + if spaces: + observation_space, action_space = spaces + else: + logger.log('Creating dummy env object to get spaces') + with logger.scoped_configure(format_strs=[]): + dummy = env_fns[0]() + observation_space, action_space = dummy.observation_space, dummy.action_space + dummy.close() + del dummy + VecEnv.__init__(self, len(env_fns), observation_space, action_space) + self.obs_keys, self.obs_shapes, self.obs_dtypes = obs_space_info(observation_space) + self.obs_bufs = [ + {k: Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k]))) for k in self.obs_keys} + for _ in env_fns] + self.parent_pipes = [] + self.procs = [] + for env_fn, obs_buf in zip(env_fns, self.obs_bufs): + wrapped_fn = CloudpickleWrapper(env_fn) + parent_pipe, child_pipe = Pipe() + proc = Process(target=_subproc_worker, + args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes, self.obs_dtypes, self.obs_keys)) + proc.daemon = True + self.procs.append(proc) + self.parent_pipes.append(parent_pipe) + proc.start() + child_pipe.close() + self.waiting_step = False + self.viewer = None + + def reset(self): + if self.waiting_step: + logger.warn('Called reset() while waiting for the step to complete') + self.step_wait() + for pipe in self.parent_pipes: + pipe.send(('reset', None)) + return self._decode_obses([pipe.recv() for pipe in self.parent_pipes]) + + def step_async(self, actions): + assert len(actions) == len(self.parent_pipes) + for pipe, act in zip(self.parent_pipes, actions): + pipe.send(('step', act)) + + def step_wait(self): + outs = [pipe.recv() for pipe in self.parent_pipes] + obs, rews, dones, infos = zip(*outs) + return self._decode_obses(obs), np.array(rews), np.array(dones), infos + + def close(self): + if self.waiting_step: + self.step_wait() + for pipe in self.parent_pipes: + pipe.send(('close', None)) + for pipe in self.parent_pipes: + pipe.recv() + pipe.close() + for proc in self.procs: + proc.join() + if self.viewer is not None: + self.viewer.close() + + def render(self, mode='human'): + for pipe in self.parent_pipes: + pipe.send(('render', None)) + imgs = [pipe.recv() for pipe in self.parent_pipes] + bigimg = tile_images(imgs) + if mode == 'human': + if self.viewer is None: + self.viewer = rendering.SimpleImageViewer() + + self.viewer.imshow(bigimg[:, :, ::-1]) + elif mode == 'rgb_array': + return bigimg + else: + raise NotImplementedError + + def _decode_obses(self, obs): + result = {} + for k in self.obs_keys: + + bufs = [b[k] for b in self.obs_bufs] + o = [np.frombuffer(b.get_obj(), dtype=self.obs_dtypes[k]).reshape(self.obs_shapes[k]) for b in bufs] + result[k] = np.array(o) + return dict_to_obs(result) + + +def _subproc_worker(pipe, parent_pipe, env_fn_wrapper, obs_bufs, obs_shapes, obs_dtypes, keys): + """ + Control a single environment instance using IPC and + shared memory. 
+ """ + def _write_obs(maybe_dict_obs): + flatdict = obs_to_dict(maybe_dict_obs) + for k in keys: + dst = obs_bufs[k].get_obj() + dst_np = np.frombuffer(dst, dtype=obs_dtypes[k]).reshape(obs_shapes[k]) # pylint: disable=W0212 + np.copyto(dst_np, flatdict[k]) + + env = env_fn_wrapper.x() + parent_pipe.close() + try: + while True: + cmd, data = pipe.recv() + if cmd == 'reset': + pipe.send(_write_obs(env.reset())) + elif cmd == 'step': + obs, reward, done, info = env.step(data) + if done: + obs = env.reset() + pipe.send((_write_obs(obs), reward, done, info)) + elif cmd == 'render': + pipe.send(env.render(mode='rgb_array')) + elif cmd == 'close': + pipe.send(None) + break + else: + raise RuntimeError('Got unrecognized cmd %s' % cmd) + except KeyboardInterrupt: + print('ShmemVecEnv worker: got KeyboardInterrupt') + finally: + env.close() diff --git a/baselines/common/vec_env/subproc_vec_env.py b/baselines/common/vec_env/subproc_vec_env.py index e5b5b32..2f52096 100644 --- a/baselines/common/vec_env/subproc_vec_env.py +++ b/baselines/common/vec_env/subproc_vec_env.py @@ -1,7 +1,8 @@ import numpy as np from multiprocessing import Process, Pipe -from baselines.common.vec_env import VecEnv, CloudpickleWrapper +from . import VecEnv, CloudpickleWrapper from baselines.common.tile_images import tile_images +from gym.envs.classic_control import rendering def worker(remote, parent_remote, env_fn_wrapper): @@ -32,6 +33,7 @@ def worker(remote, parent_remote, env_fn_wrapper): finally: env.close() + class SubprocVecEnv(VecEnv): def __init__(self, env_fns, spaces=None): """ @@ -42,15 +44,16 @@ class SubprocVecEnv(VecEnv): nenvs = len(env_fns) self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)]) self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) - for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] + for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] for p in self.ps: - p.daemon = True # if the main process crashes, we should not cause things to hang + p.daemon = True # if the main process crashes, we should not cause things to hang p.start() for remote in self.work_remotes: remote.close() self.remotes[0].send(('get_spaces', None)) observation_space, action_space = self.remotes[0].recv() + self.viewer = None VecEnv.__init__(self, len(env_fns), observation_space, action_space) def step_async(self, actions): @@ -78,12 +81,14 @@ class SubprocVecEnv(VecEnv): if self.closed: return if self.waiting: - for remote in self.remotes: + for remote in self.remotes: remote.recv() for remote in self.remotes: remote.send(('close', None)) for p in self.ps: p.join() + if self.viewer is not None: + self.viewer.close() self.closed = True def render(self, mode='human'): @@ -92,10 +97,12 @@ class SubprocVecEnv(VecEnv): imgs = [pipe.recv() for pipe in self.remotes] bigimg = tile_images(imgs) if mode == 'human': - import cv2 - cv2.imshow('vecenv', bigimg[:,:,::-1]) - cv2.waitKey(1) + if self.viewer is None: + self.viewer = rendering.SimpleImageViewer() + + self.viewer.imshow(bigimg[:, :, ::-1]) + elif mode == 'rgb_array': return bigimg else: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/baselines/common/vec_env/test_vec_env.py b/baselines/common/vec_env/test_vec_env.py new file mode 100644 index 0000000..6d0d41c --- /dev/null +++ b/baselines/common/vec_env/test_vec_env.py @@ -0,0 +1,85 @@ +""" +Tests for asynchronous vectorized environments. 
+""" + +import gym +import numpy as np +import pytest +from .dummy_vec_env import DummyVecEnv +from .shmem_vec_env import ShmemVecEnv +from .subproc_vec_env import SubprocVecEnv + + +@pytest.mark.parametrize('klass', (ShmemVecEnv, SubprocVecEnv)) +@pytest.mark.parametrize('dtype', ('uint8', 'float32')) +def test_vec_env(klass, dtype): # pylint: disable=R0914 + """ + Test that a vectorized environment is equivalent to + DummyVecEnv, since DummyVecEnv is less likely to be + error prone. + """ + num_envs = 3 + num_steps = 100 + shape = (3, 8) + + def make_fn(seed): + """ + Get an environment constructor with a seed. + """ + return lambda: _SimpleEnv(seed, shape, dtype) + fns = [make_fn(i) for i in range(num_envs)] + env1 = DummyVecEnv(fns) + env2 = klass(fns) + try: + obs1, obs2 = env1.reset(), env2.reset() + assert np.array(obs1).shape == np.array(obs2).shape + assert np.allclose(obs1, obs2) + np.random.seed(1337) + for _ in range(num_steps): + joint_shape = (len(fns),) + shape + actions = np.array(np.random.randint(0, 0x100, size=joint_shape), + dtype=dtype) + for env in [env1, env2]: + env.step_async(actions) + outs1 = env1.step_wait() + outs2 = env2.step_wait() + for out1, out2 in zip(outs1[:3], outs2[:3]): + assert np.array(out1).shape == np.array(out2).shape + assert np.allclose(out1, out2) + assert list(outs1[3]) == list(outs2[3]) + finally: + env1.close() + env2.close() + + +class _SimpleEnv(gym.Env): + """ + An environment with a pre-determined observation space + and RNG seed. + """ + + def __init__(self, seed, shape, dtype): + np.random.seed(seed) + self._dtype = dtype + self._start_obs = np.array(np.random.randint(0, 0x100, size=shape), + dtype=dtype) + self._max_steps = seed + 1 + self._cur_obs = None + self._cur_step = 0 + self.action_space = gym.spaces.Box(low=0, high=100, shape=shape, dtype=dtype) + self.observation_space = self.action_space + + def step(self, action): + self._cur_obs += np.array(action, dtype=self._dtype) + self._cur_step += 1 + done = self._cur_step >= self._max_steps + reward = self._cur_step / self._max_steps + return self._cur_obs, reward, done, {'foo': 'bar' + str(reward)} + + def reset(self): + self._cur_obs = self._start_obs + self._cur_step = 0 + return self._cur_obs + + def render(self, mode=None): + raise NotImplementedError diff --git a/baselines/common/vec_env/util.py b/baselines/common/vec_env/util.py new file mode 100644 index 0000000..d29d8a3 --- /dev/null +++ b/baselines/common/vec_env/util.py @@ -0,0 +1,59 @@ +""" +Helpers for dealing with vectorized environments. +""" + +from collections import OrderedDict + +import gym +import numpy as np + + +def copy_obs_dict(obs): + """ + Deep-copy an observation dict. + """ + return {k: np.copy(v) for k, v in obs.items()} + + +def dict_to_obs(obs_dict): + """ + Convert an observation dict into a raw array if the + original observation space was not a Dict space. + """ + if set(obs_dict.keys()) == {None}: + return obs_dict[None] + return obs_dict + + +def obs_space_info(obs_space): + """ + Get dict-structured information about a gym.Space. + + Returns: + A tuple (keys, shapes, dtypes): + keys: a list of dict keys. + shapes: a dict mapping keys to shapes. + dtypes: a dict mapping keys to dtypes. 
+ """ + if isinstance(obs_space, gym.spaces.Dict): + assert isinstance(obs_space.spaces, OrderedDict) + subspaces = obs_space.spaces + else: + subspaces = {None: obs_space} + keys = [] + shapes = {} + dtypes = {} + for key, box in subspaces.items(): + keys.append(key) + shapes[key] = box.shape + dtypes[key] = box.dtype + return keys, shapes, dtypes + + +def obs_to_dict(obs): + """ + Convert an observation into a dict. + """ + if isinstance(obs, dict): + return obs + return {None: obs} diff --git a/baselines/common/vec_env/vec_frame_stack.py b/baselines/common/vec_env/vec_frame_stack.py index 0bbcbdb..9185873 100644 --- a/baselines/common/vec_env/vec_frame_stack.py +++ b/baselines/common/vec_env/vec_frame_stack.py @@ -1,18 +1,16 @@ -from baselines.common.vec_env import VecEnvWrapper +from . import VecEnvWrapper import numpy as np from gym import spaces + class VecFrameStack(VecEnvWrapper): - """ - Vectorized environment base class - """ def __init__(self, venv, nstack): self.venv = venv self.nstack = nstack - wos = venv.observation_space # wrapped ob space + wos = venv.observation_space # wrapped ob space low = np.repeat(wos.low, self.nstack, axis=-1) high = np.repeat(wos.high, self.nstack, axis=-1) - self.stackedobs = np.zeros((venv.num_envs,)+low.shape, low.dtype) + self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype) observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype) VecEnvWrapper.__init__(self, venv, observation_space=observation_space) @@ -26,9 +24,6 @@ class VecFrameStack(VecEnvWrapper): return self.stackedobs, rews, news, infos def reset(self): - """ - Reset all environments - """ obs = self.venv.reset() self.stackedobs[...] = 0 self.stackedobs[..., -obs.shape[-1]:] = obs diff --git a/baselines/common/vec_env/vec_monitor.py b/baselines/common/vec_env/vec_monitor.py new file mode 100644 index 0000000..0074aee --- /dev/null +++ b/baselines/common/vec_env/vec_monitor.py @@ -0,0 +1,29 @@ +from . import VecEnvWrapper +import numpy as np + + +class VecMonitor(VecEnvWrapper): + def __init__(self, venv): + VecEnvWrapper.__init__(self, venv) + self.eprets = None + self.eplens = None + + def reset(self): + obs = self.venv.reset() + self.eprets = np.zeros(self.num_envs, 'f') + self.eplens = np.zeros(self.num_envs, 'i') + return obs + + def step_wait(self): + obs, rews, dones, infos = self.venv.step_wait() + self.eprets += rews + self.eplens += 1 + newinfos = [] + for (i, (done, ret, eplen, info)) in enumerate(zip(dones, self.eprets, self.eplens, infos)): + info = info.copy() + if done: + info['episode'] = {'r': ret, 'l': eplen} + self.eprets[i] = 0 + self.eplens[i] = 0 + newinfos.append(info) + return obs, rews, dones, newinfos diff --git a/baselines/common/vec_env/vec_normalize.py b/baselines/common/vec_env/vec_normalize.py index 5d5c5ad..cd80e20 100644 --- a/baselines/common/vec_env/vec_normalize.py +++ b/baselines/common/vec_env/vec_normalize.py @@ -1,17 +1,18 @@ -from baselines.common.vec_env import VecEnvWrapper +from . import VecEnvWrapper from baselines.common.running_mean_std import RunningMeanStd import numpy as np + class VecNormalize(VecEnvWrapper): """ - Vectorized environment base class + A vectorized wrapper that normalizes the observations + and returns from an environment. 
""" + def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8): VecEnvWrapper.__init__(self, venv) self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None self.ret_rms = RunningMeanStd(shape=()) if ret else None - #self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None - #self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None self.clipob = clipob self.cliprew = cliprew self.ret = np.zeros(self.num_envs) @@ -19,12 +20,6 @@ class VecNormalize(VecEnvWrapper): self.epsilon = epsilon def step_wait(self): - """ - Apply sequence of actions to sequence of environments - actions -> (observations, rewards, news) - - where 'news' is a boolean vector indicating whether each element is new. - """ obs, rews, news, infos = self.venv.step_wait() self.ret = self.ret * self.gamma + rews obs = self._obfilt(obs) @@ -42,8 +37,5 @@ class VecNormalize(VecEnvWrapper): return obs def reset(self): - """ - Reset all environments - """ obs = self.venv.reset() return self._obfilt(obs) diff --git a/baselines/deepq/README.md b/baselines/deepq/README.md index 8fa19ab..51fcbef 100644 --- a/baselines/deepq/README.md +++ b/baselines/deepq/README.md @@ -9,44 +9,29 @@ Here's a list of commands to run to quickly get a working example: ```bash # Train model and save the results to cartpole_model.pkl -python -m baselines.deepq.experiments.train_cartpole +python -m baselines.run --alg=deepq --env=CartPole-v0 --save_path=./cartpole_model.pkl --num_timesteps=1e5 # Load the model saved in cartpole_model.pkl and visualize the learned policy -python -m baselines.deepq.experiments.enjoy_cartpole +python -m baselines.run --alg=deepq --env=CartPole-v0 --load_path=./cartpole_model.pkl --num_timesteps=0 --play ``` - -Be sure to check out the source code of [both](experiments/train_cartpole.py) [files](experiments/enjoy_cartpole.py)! - ## If you wish to apply DQN to solve a problem. Check out our simple agent trained with one stop shop `deepq.learn` function. - [baselines/deepq/experiments/train_cartpole.py](experiments/train_cartpole.py) - train a Cartpole agent. -- [baselines/deepq/experiments/train_pong.py](experiments/train_pong.py) - train a Pong agent using convolutional neural networks. -In particular notice that once `deepq.learn` finishes training it returns `act` function which can be used to select actions in the environment. Once trained you can easily save it and load at later time. For both of the files listed above there are complimentary files `enjoy_cartpole.py` and `enjoy_pong.py` respectively, that load and visualize the learned policy. +In particular notice that once `deepq.learn` finishes training it returns `act` function which can be used to select actions in the environment. Once trained you can easily save it and load at later time. Complimentary file `enjoy_cartpole.py` loads and visualizes the learned policy. ## If you wish to experiment with the algorithm ##### Check out the examples - - [baselines/deepq/experiments/custom_cartpole.py](experiments/custom_cartpole.py) - Cartpole training with more fine grained control over the internals of DQN algorithm. -- [baselines/deepq/experiments/run_atari.py](experiments/run_atari.py) - more robust setup for training at scale. - - -##### Download a pretrained Atari agent - -For some research projects it is sometimes useful to have an already trained agent handy. There's a variety of models to choose from. 
You can list them all by running: +- [baselines/deepq/defaults.py](defaults.py) - settings for training on atari. Run ```bash -python -m baselines.deepq.experiments.atari.download_model +python -m baselines.run --alg=deepq --env=PongNoFrameskip-v4 ``` +to train on Atari Pong (see more in repo-wide [README.md](../../README.md#training-models)) -Once you pick a model, you can download it and visualize the learned policy. Be sure to pass `--dueling` flag to visualization script when using dueling models. -```bash -python -m baselines.deepq.experiments.atari.download_model --blob model-atari-duel-pong-1 --model-dir /tmp/models -python -m baselines.deepq.experiments.atari.enjoy --model-dir /tmp/models/model-atari-duel-pong-1 --env Pong --dueling - -``` diff --git a/baselines/deepq/build_graph.py b/baselines/deepq/build_graph.py index e9ff1a4..dd96f0e 100644 --- a/baselines/deepq/build_graph.py +++ b/baselines/deepq/build_graph.py @@ -309,7 +309,7 @@ def build_act_with_param_noise(make_obs_ph, q_func, num_actions, scope="deepq", outputs=output_actions, givens={update_eps_ph: -1.0, stochastic_ph: True, reset_ph: False, update_param_noise_threshold_ph: False, update_param_noise_scale_ph: False}, updates=updates) - def act(ob, reset, update_param_noise_threshold, update_param_noise_scale, stochastic=True, update_eps=-1): + def act(ob, reset=False, update_param_noise_threshold=False, update_param_noise_scale=False, stochastic=True, update_eps=-1): return _act(ob, stochastic, update_eps, reset, update_param_noise_threshold, update_param_noise_scale) return act diff --git a/baselines/deepq/deepq.py b/baselines/deepq/deepq.py index 7d44acf..01921bb 100644 --- a/baselines/deepq/deepq.py +++ b/baselines/deepq/deepq.py @@ -27,7 +27,7 @@ class ActWrapper(object): self.initial_state = None @staticmethod - def load_act(self, path): + def load_act(path): with open(path, "rb") as f: model_data, act_params = cloudpickle.load(f) act = deepq.build_act(**act_params) @@ -70,6 +70,7 @@ class ActWrapper(object): def save(self, path): save_state(path) + self.save_act(path+".pickle") def load_act(path): @@ -194,8 +195,9 @@ def learn(env, # capture the shape outside the closure so that the env object is not serialized # by cloudpickle when serializing make_obs_ph + observation_space = env.observation_space def make_obs_ph(name): - return ObservationInput(env.observation_space, name=name) + return ObservationInput(observation_space, name=name) act, train, update_target, debug = deepq.build_train( make_obs_ph=make_obs_ph, diff --git a/baselines/deepq/experiments/run_atari.py b/baselines/deepq/experiments/run_atari.py index b6b427b..aa60001 100644 --- a/baselines/deepq/experiments/run_atari.py +++ b/baselines/deepq/experiments/run_atari.py @@ -23,17 +23,15 @@ def main(): env = make_atari(args.env) env = bench.Monitor(env, logger.get_dir()) env = deepq.wrap_atari_dqn(env) - model = deepq.models.cnn_to_mlp( - convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], - hiddens=[256], - dueling=bool(args.dueling), - ) deepq.learn( env, - q_func=model, + "conv_only", + convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)], + hiddens=[256], + dueling=bool(args.dueling), lr=1e-4, - max_timesteps=args.num_timesteps, + total_timesteps=args.num_timesteps, buffer_size=10000, exploration_fraction=0.1, exploration_final_eps=0.01, diff --git a/baselines/deepq/experiments/train_cartpole.py b/baselines/deepq/experiments/train_cartpole.py index a50c242..cfbbdc9 100644 --- a/baselines/deepq/experiments/train_cartpole.py +++ 
b/baselines/deepq/experiments/train_cartpole.py @@ -11,12 +11,11 @@ def callback(lcl, _glb): def main(): env = gym.make("CartPole-v0") - model = deepq.models.mlp([64]) act = deepq.learn( env, - q_func=model, + network='mlp', lr=1e-3, - max_timesteps=100000, + total_timesteps=100000, buffer_size=50000, exploration_fraction=0.1, exploration_final_eps=0.02, diff --git a/baselines/ppo2/README.md b/baselines/ppo2/README.md index 4c262ad..4d431bc 100644 --- a/baselines/ppo2/README.md +++ b/baselines/ppo2/README.md @@ -2,5 +2,7 @@ - Original paper: https://arxiv.org/abs/1707.06347 - Baselines blog post: https://blog.openai.com/openai-baselines-ppo/ -- `python -m baselines.ppo2.run_atari` runs the algorithm for 40M frames = 10M timesteps on an Atari game. See help (`-h`) for more options. -- `python -m baselines.ppo2.run_mujoco` runs the algorithm for 1M frames on a Mujoco environment. + +- `python -m baselines.run --alg=ppo2 --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on Atari Pong. See help (`-h`) for more options. +- `python -m baselines.run --alg=ppo2 --env=Ant-v2 --num_timesteps=1e6` runs the algorithm for 1M frames on a Mujoco Ant environment. +- also refer to the repo-wide [README.md](../../README.md#training-models) diff --git a/baselines/run.py b/baselines/run.py index cba8515..15424b9 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -5,6 +5,7 @@ import os.path as osp import gym from collections import defaultdict import tensorflow as tf +import numpy as np from baselines.common.vec_env.vec_frame_stack import VecFrameStack from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_mujoco_env, make_atari_env @@ -75,10 +76,10 @@ def train(args, extra_args): return model, env -def build_env(args, render=False): +def build_env(args): ncpu = multiprocessing.cpu_count() if sys.platform == 'darwin': ncpu //= 2 - nenv = args.num_env or ncpu if not render else 1 + nenv = args.num_env or ncpu alg = args.alg rank = MPI.COMM_WORLD.Get_rank() if MPI else 0 seed = args.seed @@ -123,14 +124,18 @@ env = bench.Monitor(env, logger.get_dir()) env = retro_wrappers.wrap_deepmind_retro(env) - elif env_type == 'classic': + elif env_type == 'classic_control': def make_env(): e = gym.make(env_id) + e = bench.Monitor(e, logger.get_dir(), allow_early_resets=True) e.seed(seed) return e env = DummyVecEnv([make_env]) - + + else: + raise ValueError('Unknown env_type {}'.format(env_type)) + return env @@ -149,7 +154,7 @@ def get_env_type(env_id): return env_type, env_id def get_default_network(env_type): - if env_type == 'mujoco' or env_type=='classic': + if env_type == 'mujoco' or env_type == 'classic_control': return 'mlp' if env_type == 'atari': return 'cnn' @@ -215,12 +220,14 @@ def main(): if args.play: logger.log("Running trained model") - env = build_env(args, render=True) + env = build_env(args) obs = env.reset() while True: actions = model.step(obs)[0] obs, _, done, _ = env.step(actions) env.render() + done = done.any() if isinstance(done, np.ndarray) else done + if done: obs = env.reset() diff --git a/baselines/trpo_mpi/README.md b/baselines/trpo_mpi/README.md index b3d9b9d..4cdbb5a 100644 --- a/baselines/trpo_mpi/README.md +++ b/baselines/trpo_mpi/README.md @@ -2,5 +2,6 @@ - Original paper: https://arxiv.org/abs/1502.05477 - Baselines blog post https://blog.openai.com/openai-baselines-ppo/ -- `mpirun -np 16 python -m baselines.trpo_mpi.run_atari` runs the algorithm for 40M frames = 10M timesteps on an Atari
game. See help (`-h`) for more options. -- `python -m baselines.trpo_mpi.run_mujoco` runs the algorithm for 1M timesteps on a Mujoco environment. +- `mpirun -np 16 python -m baselines.run --alg=trpo_mpi --env=PongNoFrameskip-v4` runs the algorithm for 40M frames = 10M timesteps on Atari Pong. See help (`-h`) for more options. +- `python -m baselines.run --alg=trpo_mpi --env=Ant-v2 --num_timesteps=1e6` runs the algorithm for 1M timesteps on a Mujoco Ant environment. +- also refer to the repo-wide [README.md](../../README.md#training-models) diff --git a/baselines/trpo_mpi/defaults.py b/baselines/trpo_mpi/defaults.py index 96b6cb3..0b58d18 100644 --- a/baselines/trpo_mpi/defaults.py +++ b/baselines/trpo_mpi/defaults.py @@ -1,4 +1,4 @@ -from rl_common.models import mlp, cnn_small +from baselines.common.models import mlp, cnn_small def atari(): diff --git a/conftest.py b/conftest.py deleted file mode 100644 index 3493c45..0000000 --- a/conftest.py +++ /dev/null @@ -1,19 +0,0 @@ -import pytest - - -def pytest_addoption(parser): - parser.addoption('--runslow', action='store_true', default=False, help='run slow tests') - - -def pytest_collection_modifyitems(config, items): - if config.getoption('--runslow'): - # --runslow given in cli: do not skip slow tests - return - skip_slow = pytest.mark.skip(reason='need --runslow option to run') - slow_tests = [] - for item in items: - if 'slow' in item.keywords: - slow_tests.append(item.name) - item.add_marker(skip_slow) - - print('skipping slow tests', ' '.join(slow_tests), 'use --runslow to run this')
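
The centerpiece of this patch is the reworked vectorized-environment stack (`ShmemVecEnv`, the `util` helpers, `VecMonitor`). For orientation, here is a minimal usage sketch of the API these files expose; it is not part of the patch itself, and the environment choice (`CartPole-v0`), the four-worker setup, and the `make_env` helper are illustrative assumptions.

```python
# Minimal sketch of the vectorized-env API introduced/reworked in this patch.
# Assumes gym and baselines (with this patch applied) are installed;
# CartPole-v0 and the 4-worker setup are illustrative choices only.
# Note: ShmemVecEnv supports the observation dtypes listed in _NP_TO_CT
# (e.g. uint8, float32) in shmem_vec_env.py.
import gym
import numpy as np

from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.shmem_vec_env import ShmemVecEnv
from baselines.common.vec_env.vec_monitor import VecMonitor


def make_env(seed):
    # Return a constructor so each subprocess builds its own env instance.
    def _thunk():
        env = gym.make('CartPole-v0')
        env.seed(seed)
        return env
    return _thunk


if __name__ == '__main__':
    env_fns = [make_env(i) for i in range(4)]

    # ShmemVecEnv runs each env in a subprocess and ships observations back
    # through shared memory; DummyVecEnv(env_fns) is the in-process reference
    # implementation that test_vec_env.py compares against.
    venv = VecMonitor(ShmemVecEnv(env_fns))

    obs = venv.reset()  # shape: (num_envs,) + observation_space.shape
    for _ in range(100):
        actions = np.array([venv.action_space.sample() for _ in range(venv.num_envs)])
        obs, rews, dones, infos = venv.step(actions)
        # VecMonitor adds info['episode'] = {'r': ..., 'l': ...} when an env finishes.
    venv.close()
```

The design rationale suggested by the code is that shared-memory buffers avoid pickling every observation through a pipe, which matters for large image observations; for small observation spaces, DummyVecEnv or SubprocVecEnv remain reasonable choices.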