Add RecordVideo wrapper (#2300)

* Add RecordVideo wrapper

* bug fix

* don't change gym's core API

* add test cases

* reformat
Costa Huang
2021-08-18 16:36:40 -04:00
committed by GitHub
parent ee30486bfb
commit 2853ce4797
7 changed files with 140 additions and 7 deletions
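
For orientation, this is roughly how the new wrapper is used on a single environment (a minimal sketch mirroring the test added below; the environment id, folder name, and trigger are illustrative):

    import gym

    env = gym.make("CartPole-v1")
    # With the default video_length=0 a clip runs until the episode ends;
    # reset() starts a recording, and the trigger can start new ones later.
    env = gym.wrappers.RecordVideo(
        env, "videos", record_video_trigger=lambda step: step % 100 == 0
    )

    env.reset()
    done = False
    while not done:
        _, _, done, _ = env.step(env.action_space.sample())
    env.close()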

View File

@@ -92,9 +92,10 @@ class AsyncVectorEnv(VectorEnv):
         self.env_fns = env_fns
         self.shared_memory = shared_memory
         self.copy = copy
+        dummy_env = env_fns[0]()
+        self.metadata = dummy_env.metadata
 
         if (observation_space is None) or (action_space is None):
-            dummy_env = env_fns[0]()
             observation_space = observation_space or dummy_env.observation_space
             action_space = action_space or dummy_env.action_space
             dummy_env.close()

View File

@@ -33,6 +33,7 @@ class SyncVectorEnv(VectorEnv):
         self.env_fns = env_fns
         self.envs = [env_fn() for env_fn in env_fns]
         self.copy = copy
+        self.metadata = self.envs[0].metadata
 
         if (observation_space is None) or (action_space is None):
             observation_space = observation_space or self.envs[0].observation_space
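
Exposing metadata on SyncVectorEnv (and on AsyncVectorEnv above) matters because consumers such as the video recorder read rendering information from env.metadata. A rough illustration; the "video.frames_per_second" key is the conventional classic-control metadata key in this gym version and is shown only as an example:

    import gym

    envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
    # The vector env now mirrors the first sub-environment's metadata instead of
    # gym.Env's default, so lookups behave as they would on a single env.
    print(envs.metadata.get("video.frames_per_second"))  # e.g. 50 for CartPole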

View File

@@ -32,6 +32,7 @@ class VectorEnv(gym.Env):
     def __init__(self, num_envs, observation_space, action_space):
         super(VectorEnv, self).__init__()
         self.num_envs = num_envs
+        self.is_vector_env = True
         self.observation_space = batch_space(observation_space, n=num_envs)
         self.action_space = Tuple((action_space,) * num_envs)
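
The is_vector_env flag lets wrappers detect a batched environment with a plain getattr instead of importing the vector classes. A minimal sketch of the pattern (the wrapper name is hypothetical; it mirrors what RecordEpisodeStatistics and RecordVideo do below):

    import gym

    class BatchAwareWrapper(gym.Wrapper):
        def __init__(self, env):
            super().__init__(env)
            # True for VectorEnv instances, False for ordinary envs.
            self.is_vector_env = getattr(env, "is_vector_env", False)

        def step(self, action):
            observations, rewards, dones, infos = self.env.step(action)
            if not self.is_vector_env:
                # Normalize to the batched shape so the bookkeeping below can
                # treat single and vectorized envs uniformly.
                dones, infos = [dones], [infos]
            # ... per-environment bookkeeping over zip(dones, infos) ...
            return (
                observations,
                rewards,
                dones if self.is_vector_env else dones[0],
                infos if self.is_vector_env else infos[0],
            )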

View File

@@ -14,3 +14,4 @@ from gym.wrappers.transform_reward import TransformReward
from gym.wrappers.resize_observation import ResizeObservation
from gym.wrappers.clip_action import ClipAction
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
from gym.wrappers.record_video import RecordVideo

View File

@@ -7,7 +7,6 @@ import gym
 class RecordEpisodeStatistics(gym.Wrapper):
     def __init__(self, env, deque_size=100):
         super(RecordEpisodeStatistics, self).__init__(env)
-        self.env_is_vec = isinstance(env, gym.vector.VectorEnv)
         self.num_envs = getattr(env, "num_envs", 1)
         self.t0 = (
             time.time()
@@ -17,6 +16,7 @@ class RecordEpisodeStatistics(gym.Wrapper):
         self.episode_lengths = None
         self.return_queue = deque(maxlen=deque_size)
         self.length_queue = deque(maxlen=deque_size)
+        self.is_vector_env = getattr(env, "is_vector_env", False)
 
     def reset(self, **kwargs):
         observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
@@ -30,7 +30,7 @@ class RecordEpisodeStatistics(gym.Wrapper):
         )
         self.episode_returns += rewards
         self.episode_lengths += 1
-        if not self.env_is_vec:
+        if not self.is_vector_env:
             infos = [infos]
             dones = [dones]
         for i in range(len(dones)):
@@ -52,6 +52,6 @@ class RecordEpisodeStatistics(gym.Wrapper):
         return (
             observations,
             rewards,
-            dones if self.env_is_vec else dones[0],
-            infos if self.env_is_vec else infos[0],
+            dones if self.is_vector_env else dones[0],
+            infos if self.is_vector_env else infos[0],
         )
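
The switch from env_is_vec to the shared is_vector_env flag does not change how the wrapper is consumed. For reference, a short sketch of reading the statistics it attaches to info when an episode finishes (keys as populated by this wrapper):

    import gym

    env = gym.wrappers.RecordEpisodeStatistics(gym.make("CartPole-v1"))
    env.reset()
    done = False
    while not done:
        _, _, done, info = env.step(env.action_space.sample())
    # "r" is the episodic return, "l" the episode length in steps.
    print(info["episode"]["r"], info["episode"]["l"])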

View File

@@ -0,0 +1,79 @@
import os
import gym
from typing import Callable

from gym.wrappers.monitoring import video_recorder


class RecordVideo(gym.Wrapper):
    def __init__(
        self,
        env,
        video_folder: str,
        record_video_trigger: Callable[[int], bool],
        video_length: int = 0,
        name_prefix: str = "rl-video",
    ):
        super(RecordVideo, self).__init__(env)
        self.record_video_trigger = record_video_trigger
        self.video_recorder = None

        self.video_folder = os.path.abspath(video_folder)
        # Create output folder if needed
        os.makedirs(self.video_folder, exist_ok=True)

        self.name_prefix = name_prefix
        self.step_id = 0
        self.video_length = video_length

        self.recording = False
        self.recorded_frames = 0
        self.is_vector_env = getattr(env, "is_vector_env", False)

    def reset(self, **kwargs):
        observations = super(RecordVideo, self).reset(**kwargs)
        self.start_video_recorder()
        return observations

    def start_video_recorder(self):
        self.close_video_recorder()

        video_name = f"{self.name_prefix}-step-{self.step_id}"
        base_path = os.path.join(self.video_folder, video_name)
        self.video_recorder = video_recorder.VideoRecorder(
            env=self.env, base_path=base_path, metadata={"step_id": self.step_id}
        )

        self.video_recorder.capture_frame()
        self.recorded_frames = 1
        self.recording = True

    def _video_enabled(self):
        return self.record_video_trigger(self.step_id)

    def step(self, action):
        observations, rewards, dones, infos = super(RecordVideo, self).step(action)

        self.step_id += 1
        if self.recording:
            self.video_recorder.capture_frame()
            self.recorded_frames += 1
            if self.video_length > 0:
                if self.recorded_frames > self.video_length:
                    self.close_video_recorder()
            else:
                if not self.is_vector_env:
                    dones = [dones]
                if dones[0]:
                    self.close_video_recorder()

        elif self._video_enabled():
            self.start_video_recorder()

        return observations, rewards, dones if self.is_vector_env else dones[0], infos

    def close_video_recorder(self) -> None:
        if self.recording:
            self.video_recorder.close()
        self.recording = False
        self.recorded_frames = 1
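
With the default video_length=0 a clip runs until the first episode in the batch ends; a positive video_length instead stops the clip after that many captured frames. A new recording starts on every reset() and, when nothing is being recorded, whenever record_video_trigger(step_id) returns True; clips are written under video_folder with names built from name_prefix and the step id. A short sketch of recording several episodes (folder and prefix are illustrative; encoding relies on the existing monitoring.video_recorder machinery, which needs ffmpeg):

    import os
    import gym

    env = gym.wrappers.RecordVideo(
        gym.make("CartPole-v1"),
        "videos",
        record_video_trigger=lambda step: step % 100 == 0,
        name_prefix="cartpole",
    )
    for _ in range(3):
        done = False
        env.reset()  # each reset starts a new clip
        while not done:
            _, _, done, _ = env.step(env.action_space.sample())
    env.close()
    print(sorted(os.listdir("videos")))  # e.g. cartpole-step-0.meta.json, cartpole-step-0.mp4, ...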

View File

@@ -0,0 +1,50 @@
import pytest
import os
import shutil

import gym
from gym.wrappers import RecordEpisodeStatistics, RecordVideo


def test_record_video():
    env = gym.make("CartPole-v1")
    env = gym.wrappers.RecordVideo(
        env, "videos", record_video_trigger=lambda x: x % 100 == 0
    )
    env.reset()
    for _ in range(200):
        action = env.action_space.sample()
        _, _, done, _ = env.step(action)
        if done:
            env.close()
            break

    assert os.path.isdir("videos")
    shutil.rmtree("videos")


def make_env(gym_id, seed):
    def thunk():
        env = gym.make(gym_id)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        env = gym.wrappers.RecordVideo(
            env, "videos", record_video_trigger=lambda x: x % 100 == 0
        )
        return env

    return thunk


def test_record_video_vector():
    envs = gym.vector.SyncVectorEnv([make_env("CartPole-v1", 1 + i) for i in range(2)])
    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    envs.reset()
    for i in range(100):
        _, _, _, infos = envs.step(envs.action_space.sample())
        for info in infos:
            if "episode" in info.keys():
                print(f"{i}, episode_reward={info['episode']['r']}")
                break

    assert os.path.isdir("videos")
    shutil.rmtree("videos")