diff --git a/.travis.yml b/.travis.yml index 712fc84..7ca7e6a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,4 +11,4 @@ install: script: - flake8 . --show-source --statistics - - docker run baselines-test pytest -v . + - docker run baselines-test pytest -v --forked . diff --git a/Dockerfile b/Dockerfile index 49a9c79..12e67be 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,8 @@ FROM python:3.6 +RUN apt-get -y update && apt-get -y install ffmpeg # RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake python-opencv + ENV CODE_DIR /root/code COPY . $CODE_DIR/baselines diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py index 162e34d..90b9868 100644 --- a/baselines/common/cmd_util.py +++ b/baselines/common/cmd_util.py @@ -131,6 +131,8 @@ def common_arg_parser(): parser.add_argument('--num_env', help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco', default=None, type=int) parser.add_argument('--reward_scale', help='Reward scale factor. Default: 1.0', default=1.0, type=float) parser.add_argument('--save_path', help='Path to save trained model to', default=None, type=str) + parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int) + parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int) parser.add_argument('--play', default=False, action='store_true') return parser diff --git a/baselines/common/plot_util.py b/baselines/common/plot_util.py index 6f4c272..56f6f98 100644 --- a/baselines/common/plot_util.py +++ b/baselines/common/plot_util.py @@ -240,6 +240,8 @@ def plot_results( split_fn=default_split_fn, group_fn=default_split_fn, average_group=False, + shaded_std=True, + shaded_err=True, figsize=None, legend_outside=False, resample=0, @@ -346,8 +348,10 @@ def plot_results( ystderr = ystd / np.sqrt(len(ys)) l, = axarr[isplit][0].plot(usex, ymean, color=color) g2l[group] = l - ax.fill_between(usex, ymean - ystderr, ymean + ystderr, color=color, alpha=.4) - ax.fill_between(usex, ymean - ystd, ymean + ystd, color=color, alpha=.2) + if shaded_err: + ax.fill_between(usex, ymean - ystderr, ymean + ystderr, color=color, alpha=.4) + if shaded_std: + ax.fill_between(usex, ymean - ystd, ymean + ystd, color=color, alpha=.2) # https://matplotlib.org/users/legend_guide.html diff --git a/baselines/common/vec_env/__init__.py b/baselines/common/vec_env/__init__.py index cb60531..075a139 100644 --- a/baselines/common/vec_env/__init__.py +++ b/baselines/common/vec_env/__init__.py @@ -32,6 +32,11 @@ class VecEnv(ABC): """ closed = False viewer = None + + metadata = { + 'render.modes': ['human', 'rgb_array'] + } + def __init__(self, num_envs, observation_space, action_space): self.num_envs = num_envs self.observation_space = observation_space diff --git a/baselines/common/vec_env/dummy_vec_env.py b/baselines/common/vec_env/dummy_vec_env.py index 2b4d2ba..c2b86dd 100644 --- a/baselines/common/vec_env/dummy_vec_env.py +++ b/baselines/common/vec_env/dummy_vec_env.py @@ -20,9 +20,6 @@ class DummyVecEnv(VecEnv): env = self.envs[0] VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) obs_space = env.observation_space - if isinstance(obs_space, spaces.MultiDiscrete): - obs_space.shape = obs_space.shape[0] - self.keys, shapes, dtypes = obs_space_info(obs_space) self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys } @@ -79,6 +76,6 @@ class DummyVecEnv(VecEnv): def render(self, mode='human'): if self.num_envs == 1: - self.envs[0].render(mode=mode) + return self.envs[0].render(mode=mode) else: - super().render(mode=mode) + return super().render(mode=mode) diff --git a/baselines/common/vec_env/test_video_recorder.py b/baselines/common/vec_env/test_video_recorder.py new file mode 100644 index 0000000..363404a --- /dev/null +++ b/baselines/common/vec_env/test_video_recorder.py @@ -0,0 +1,49 @@ +""" +Tests for asynchronous vectorized environments. +""" + +import gym +import pytest +import os +import glob +import tempfile + +from .dummy_vec_env import DummyVecEnv +from .shmem_vec_env import ShmemVecEnv +from .subproc_vec_env import SubprocVecEnv +from .vec_video_recorder import VecVideoRecorder + +@pytest.mark.parametrize('klass', (DummyVecEnv, ShmemVecEnv, SubprocVecEnv)) +@pytest.mark.parametrize('num_envs', (1, 4)) +@pytest.mark.parametrize('video_length', (10, 100)) +@pytest.mark.parametrize('video_interval', (1, 50)) +def test_video_recorder(klass, num_envs, video_length, video_interval): + """ + Wrap an existing VecEnv with VevVideoRecorder, + Make (video_interval + video_length + 1) steps, + then check that the file is present + """ + + def make_fn(): + env = gym.make('PongNoFrameskip-v4') + return env + fns = [make_fn for _ in range(num_envs)] + env = klass(fns) + + with tempfile.TemporaryDirectory() as video_path: + env = VecVideoRecorder(env, video_path, record_video_trigger=lambda x: x % video_interval == 0, video_length=video_length) + + env.reset() + for _ in range(video_interval + video_length + 1): + env.step([0] * num_envs) + env.close() + + + recorded_video = glob.glob(os.path.join(video_path, "*.mp4")) + + # first and second step + assert len(recorded_video) == 2 + # Files are not empty + assert all(os.stat(p).st_size != 0 for p in recorded_video) + + diff --git a/baselines/common/vec_env/vec_video_recorder.py b/baselines/common/vec_env/vec_video_recorder.py new file mode 100644 index 0000000..b4e7059 --- /dev/null +++ b/baselines/common/vec_env/vec_video_recorder.py @@ -0,0 +1,89 @@ +import os +from baselines import logger +from baselines.common.vec_env import VecEnvWrapper +from gym.wrappers.monitoring import video_recorder + + +class VecVideoRecorder(VecEnvWrapper): + """ + Wrap VecEnv to record rendered image as mp4 video. + """ + + def __init__(self, venv, directory, record_video_trigger, video_length=200): + """ + # Arguments + venv: VecEnv to wrap + directory: Where to save videos + record_video_trigger: + Function that defines when to start recording. + The function takes the current number of step, + and returns whether we should start recording or not. + video_length: Length of recorded video + """ + + VecEnvWrapper.__init__(self, venv) + self.record_video_trigger = record_video_trigger + self.video_recorder = None + + self.directory = os.path.abspath(directory) + if not os.path.exists(self.directory): os.mkdir(self.directory) + + self.file_prefix = "vecenv" + self.file_infix = '{}'.format(os.getpid()) + self.step_id = 0 + self.video_length = video_length + + self.recording = False + self.recorded_frames = 0 + + def reset(self): + obs = self.venv.reset() + + self.start_video_recorder() + + return obs + + def start_video_recorder(self): + self.close_video_recorder() + + base_path = os.path.join(self.directory, '{}.video.{}.video{:06}'.format(self.file_prefix, self.file_infix, self.step_id)) + self.video_recorder = video_recorder.VideoRecorder( + env=self.venv, + base_path=base_path, + metadata={'step_id': self.step_id} + ) + + self.video_recorder.capture_frame() + self.recorded_frames = 1 + self.recording = True + + def _video_enabled(self): + return self.record_video_trigger(self.step_id) + + def step_wait(self): + obs, rews, dones, infos = self.venv.step_wait() + + self.step_id += 1 + if self.recording: + self.video_recorder.capture_frame() + self.recorded_frames += 1 + if self.recorded_frames > self.video_length: + logger.info("Saving video to ", self.video_recorder.path) + self.close_video_recorder() + elif self._video_enabled(): + self.start_video_recorder() + + return obs, rews, dones, infos + + def close_video_recorder(self): + if self.recording: + self.video_recorder.close() + self.recording = False + self.recorded_frames = 0 + + def close(self): + VecEnvWrapper.close(self) + self.close_video_recorder() + + def __del__(self): + self.close() diff --git a/baselines/run.py b/baselines/run.py index 28cf620..c0298f3 100644 --- a/baselines/run.py +++ b/baselines/run.py @@ -6,6 +6,7 @@ from collections import defaultdict import tensorflow as tf import numpy as np +from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder from baselines.common.vec_env.vec_frame_stack import VecFrameStack from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env from baselines.common.tf_util import get_session @@ -62,6 +63,8 @@ def train(args, extra_args): alg_kwargs.update(extra_args) env = build_env(args) + if args.save_video_interval != 0: + env = VecVideoRecorder(env, osp.join(logger.Logger.CURRENT.dir, "videos"), record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length) if args.network: alg_kwargs['network'] = args.network diff --git a/docs/viz/viz.md b/docs/viz/viz.md index a5d798b..d54ca37 100644 --- a/docs/viz/viz.md +++ b/docs/viz/viz.md @@ -12,7 +12,7 @@ Logging to /var/folders/mq/tgrn7bs17s1fnhlwt314b2fm0000gn/T/openai-2018-10-29-15 The location can be changed by changing `OPENAI_LOGDIR` environment variable; for instance: ```bash export OPENAI_LOGDIR=$HOME/logs/cartpole-ppo -python -m baselines.run --alg=ppo2 --env=CartPole-v0 --num_time steps=30000 --nsteps=128 +python -m baselines.run --alg=ppo2 --env=CartPole-v0 --num_timesteps=30000 --nsteps=128 ``` will log data to `~/logs/cartpole-ppo`. @@ -68,7 +68,7 @@ plt.plot(np.cumsum(r.monitor.l), pu.smooth(r.monitor.r, radius=10)) We can also get a similar curve by using logger summaries (instead of raw episode data in monitor.csv): ```python -plt.plot(r.progress.total_time steps, r.progress.eprewmean) +plt.plot(r.progress.total_timesteps, r.progress.eprewmean) ``` @@ -85,10 +85,10 @@ runs ppo2 with cartpole with 6 different seeds for 30k time steps, first with ba ```bash for seed in $(seq 0 5); do -OPENAI_LOGDIR=$HOME/logs/cartpole-ppo/b32-$seed python -m baselines.run --alg=ppo2 --env=CartPole-v0 --num_time steps=3e4 --seed=$seed --nsteps=32 +OPENAI_LOGDIR=$HOME/logs/cartpole-ppo/b32-$seed python -m baselines.run --alg=ppo2 --env=CartPole-v0 --num_timesteps=3e4 --seed=$seed --nsteps=32 done for seed in $(seq 0 5); do -OPENAI_LOGDIR=$HOME/logs/cartpole-ppo/b128-$seed python -m baselines.run --alg=ppo2 --env=CartPole-v0 --num_time steps=3e4 --seed=$seed --nsteps=128 +OPENAI_LOGDIR=$HOME/logs/cartpole-ppo/b128-$seed python -m baselines.run --alg=ppo2 --env=CartPole-v0 --num_timesteps=3e4 --seed=$seed --nsteps=128 done ``` These 12 runs can be loaded just as before: diff --git a/setup.py b/setup.py index f77faf0..2e5e36a 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,7 @@ extras = { 'test': [ 'filelock', 'pytest', + 'pytest-forked', 'atari-py' ], 'bullet': [