Add video recorder (#666)
* Fix: Return the result of rendering from dummyvecenv * Add: Add a video recorder wrapper for vecenv * Change: Use VecVideoRecorder with --video_monitor flag * Change: Overwrite the metadata only when it isn't defined * Add: Define __del__ to make the file correctly closed in exit * Fix: Bump epidode_id in reset() * Fix: Use hasattr to check the existence of .metadata * Fix: Make directory when it doesn't exist * Change: Kepp recording for `video_length` steps, then close Because reset() is not what it is in normal gym.Env * Add: Enable to specify video_length from command line argument * Delete: Delete default value, None, of video_callable * Change: Use self.recorded_frames and self.recording to manage intervals * Add: Log the status of video recording * Fix: Fix saving path * Change: Place metadata in the base VecEnv * Delete: Delete unused imports * Fix: epidode_id => step_id * Fix: Refine the flag name * Change: Unify the flag name folloing to previous change * [WIP] Add: Add a test of VecVideoRecorder * Fix: Use PongNoFrameskip-v0 because SimpleEnv doesn't have render() * Change; Use TemporaryDirectory * Fix: minimal successful test * Add: Test against parallel environments * Add: Test against different type of VecEnvs * Change: Test against different length and interval of video capture * Delete: Reduce the number of tests * Change: Test if the output video is not empty * Add: Add some comments * Fix: Fix the flag name * Add: Add docstrings * Fix: Install ffmpeg in testing container for VecVideoRecorder's test * Fix: Delete unused things * Fix: Replace `video_callable` with `record_video_trigger` * Fix: Improve the explanation of `record_video_trigger` argument * Fix: Close owning vecenv in VecVideoRecorder.close to resolve memory leak
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
FROM python:3.6
|
||||
|
||||
RUN apt-get -y update && apt-get -y install ffmpeg
|
||||
# RUN apt-get -y update && apt-get -y install git wget python-dev python3-dev libopenmpi-dev python-pip zlib1g-dev cmake python-opencv
|
||||
|
||||
ENV CODE_DIR /root/code
|
||||
|
||||
COPY . $CODE_DIR/baselines
|
||||
|
@@ -131,6 +131,8 @@ def common_arg_parser():
|
||||
parser.add_argument('--num_env', help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco', default=None, type=int)
|
||||
parser.add_argument('--reward_scale', help='Reward scale factor. Default: 1.0', default=1.0, type=float)
|
||||
parser.add_argument('--save_path', help='Path to save trained model to', default=None, type=str)
|
||||
parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
|
||||
parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
|
||||
parser.add_argument('--play', default=False, action='store_true')
|
||||
return parser
|
||||
|
||||
|
@@ -32,6 +32,11 @@ class VecEnv(ABC):
|
||||
"""
|
||||
closed = False
|
||||
viewer = None
|
||||
|
||||
metadata = {
|
||||
'render.modes': ['human', 'rgb_array']
|
||||
}
|
||||
|
||||
def __init__(self, num_envs, observation_space, action_space):
|
||||
self.num_envs = num_envs
|
||||
self.observation_space = observation_space
|
||||
|
@@ -79,6 +79,6 @@ class DummyVecEnv(VecEnv):
|
||||
|
||||
def render(self, mode='human'):
|
||||
if self.num_envs == 1:
|
||||
self.envs[0].render(mode=mode)
|
||||
return self.envs[0].render(mode=mode)
|
||||
else:
|
||||
super().render(mode=mode)
|
||||
return super().render(mode=mode)
|
||||
|
49
baselines/common/vec_env/test_video_recorder.py
Normal file
49
baselines/common/vec_env/test_video_recorder.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Tests for asynchronous vectorized environments.
|
||||
"""
|
||||
|
||||
import gym
|
||||
import pytest
|
||||
import os
|
||||
import glob
|
||||
import tempfile
|
||||
|
||||
from .dummy_vec_env import DummyVecEnv
|
||||
from .shmem_vec_env import ShmemVecEnv
|
||||
from .subproc_vec_env import SubprocVecEnv
|
||||
from .vec_video_recorder import VecVideoRecorder
|
||||
|
||||
@pytest.mark.parametrize('klass', (DummyVecEnv, ShmemVecEnv, SubprocVecEnv))
|
||||
@pytest.mark.parametrize('num_envs', (1, 16))
|
||||
@pytest.mark.parametrize('video_length', (10, 500))
|
||||
@pytest.mark.parametrize('video_interval', (1, 50))
|
||||
def test_video_recorder(klass, num_envs, video_length, video_interval):
|
||||
"""
|
||||
Wrap an existing VecEnv with VevVideoRecorder,
|
||||
Make (video_interval + video_length + 1) steps,
|
||||
then check that the file is present
|
||||
"""
|
||||
|
||||
def make_fn():
|
||||
env = gym.make('PongNoFrameskip-v4')
|
||||
return env
|
||||
fns = [make_fn for _ in range(num_envs)]
|
||||
env = klass(fns)
|
||||
|
||||
with tempfile.TemporaryDirectory() as video_path:
|
||||
env = VecVideoRecorder(env, video_path, record_video_trigger=lambda x: x % video_interval == 0, video_length=video_length)
|
||||
|
||||
env.reset()
|
||||
for _ in range(video_interval + video_length + 1):
|
||||
env.step([0] * num_envs)
|
||||
env.close()
|
||||
|
||||
|
||||
recorded_video = glob.glob(os.path.join(video_path, "*.mp4"))
|
||||
|
||||
# first and second step
|
||||
assert len(recorded_video) == 2
|
||||
# Files are not empty
|
||||
assert all(os.stat(p).st_size != 0 for p in recorded_video)
|
||||
|
||||
|
89
baselines/common/vec_env/vec_video_recorder.py
Normal file
89
baselines/common/vec_env/vec_video_recorder.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import os
|
||||
from baselines import logger
|
||||
from baselines.common.vec_env import VecEnvWrapper
|
||||
from gym.wrappers.monitoring import video_recorder
|
||||
|
||||
|
||||
class VecVideoRecorder(VecEnvWrapper):
|
||||
"""
|
||||
Wrap VecEnv to record rendered image as mp4 video.
|
||||
"""
|
||||
|
||||
def __init__(self, venv, directory, record_video_trigger, video_length=200):
|
||||
"""
|
||||
# Arguments
|
||||
venv: VecEnv to wrap
|
||||
directory: Where to save videos
|
||||
record_video_trigger:
|
||||
Function that defines when to start recording.
|
||||
The function takes the current number of step,
|
||||
and returns whether we should start recording or not.
|
||||
video_length: Length of recorded video
|
||||
"""
|
||||
|
||||
VecEnvWrapper.__init__(self, venv)
|
||||
self.record_video_trigger = record_video_trigger
|
||||
self.video_recorder = None
|
||||
|
||||
self.directory = os.path.abspath(directory)
|
||||
if not os.path.exists(self.directory): os.mkdir(self.directory)
|
||||
|
||||
self.file_prefix = "vecenv"
|
||||
self.file_infix = '{}'.format(os.getpid())
|
||||
self.step_id = 0
|
||||
self.video_length = video_length
|
||||
|
||||
self.recording = False
|
||||
self.recorded_frames = 0
|
||||
|
||||
def reset(self):
|
||||
obs = self.venv.reset()
|
||||
|
||||
self.start_video_recorder()
|
||||
|
||||
return obs
|
||||
|
||||
def start_video_recorder(self):
|
||||
self.close_video_recorder()
|
||||
|
||||
base_path = os.path.join(self.directory, '{}.video.{}.video{:06}'.format(self.file_prefix, self.file_infix, self.step_id))
|
||||
self.video_recorder = video_recorder.VideoRecorder(
|
||||
env=self.venv,
|
||||
base_path=base_path,
|
||||
metadata={'step_id': self.step_id}
|
||||
)
|
||||
|
||||
self.video_recorder.capture_frame()
|
||||
self.recorded_frames = 1
|
||||
self.recording = True
|
||||
|
||||
def _video_enabled(self):
|
||||
return self.record_video_trigger(self.step_id)
|
||||
|
||||
def step_wait(self):
|
||||
obs, rews, dones, infos = self.venv.step_wait()
|
||||
|
||||
self.step_id += 1
|
||||
if self.recording:
|
||||
self.video_recorder.capture_frame()
|
||||
self.recorded_frames += 1
|
||||
if self.recorded_frames > self.video_length:
|
||||
logger.info("Saving video to ", self.video_recorder.path)
|
||||
self.close_video_recorder()
|
||||
elif self._video_enabled():
|
||||
self.start_video_recorder()
|
||||
|
||||
return obs, rews, dones, infos
|
||||
|
||||
def close_video_recorder(self):
|
||||
if self.recording:
|
||||
self.video_recorder.close()
|
||||
self.recording = False
|
||||
self.recorded_frames = 0
|
||||
|
||||
def close(self):
|
||||
VecEnvWrapper.close(self)
|
||||
self.close_video_recorder()
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
@@ -6,6 +6,7 @@ from collections import defaultdict
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
|
||||
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
||||
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
|
||||
from baselines.common.tf_util import get_session
|
||||
@@ -62,6 +63,8 @@ def train(args, extra_args):
|
||||
alg_kwargs.update(extra_args)
|
||||
|
||||
env = build_env(args)
|
||||
if args.save_video_interval != 0:
|
||||
env = VecVideoRecorder(env, osp.join(logger.Logger.CURRENT.dir, "videos"), record_video_trigger=lambda x: x % args.save_video_interval == 0, video_length=args.save_video_length)
|
||||
|
||||
if args.network:
|
||||
alg_kwargs['network'] = args.network
|
||||
|
Reference in New Issue
Block a user