mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-22 07:02:19 +00:00
ENH: implement RecordVideoV0 (#246)
This commit is contained in:
@@ -6,12 +6,14 @@
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any, SupportsFloat
|
||||
from typing import Any, Callable, List, SupportsFloat
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import error, logger
|
||||
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
|
||||
@@ -79,9 +81,200 @@ class RenderCollectionV0(gym.Wrapper):
|
||||
|
||||
|
||||
class RecordVideoV0(gym.Wrapper):
|
||||
"""Record a video of an environment."""
|
||||
"""This wrapper records videos of rollouts.
|
||||
|
||||
pass
|
||||
Usually, you only want to record episodes intermittently, say every hundredth episode.
|
||||
To do this, you can specify ``episode_trigger`` or ``step_trigger``.
|
||||
They should be functions returning a boolean that indicates whether a recording should be started at the
|
||||
current episode or step, respectively.
|
||||
If neither :attr:`episode_trigger` nor ``step_trigger`` is passed, a default ``episode_trigger`` will be employed,
|
||||
i.e. capped_cubic_video_schedule. This function starts a video at every episode that is a power of 3 until 1000 and
|
||||
then every 1000 episodes.
|
||||
By default, the recording will be stopped once reset is called. However, you can also create recordings of fixed
|
||||
length (possibly spanning several episodes) by passing a strictly positive value for ``video_length``.
|
||||
This wrapper uses the value `fps` from metadata as the number of frames per second;
|
||||
if `fps` is not defined in metadata, the default value 30 is used.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
video_folder: str,
|
||||
episode_trigger: Callable[[int], bool] = None,
|
||||
step_trigger: Callable[[int], bool] = None,
|
||||
video_length: int = 0,
|
||||
name_prefix: str = "rl-video",
|
||||
disable_logger: bool = False,
|
||||
):
|
||||
"""Wrapper records videos of rollouts.
|
||||
|
||||
Args:
|
||||
env: The environment that will be wrapped
|
||||
video_folder (str): The folder where the recordings will be stored
|
||||
episode_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this episode
|
||||
step_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this step
|
||||
video_length (int): The length of recorded episodes. If 0, entire episodes are recorded.
|
||||
Otherwise, snippets of the specified length are captured
|
||||
name_prefix (str): Will be prepended to the filename of the recordings
|
||||
disable_logger (bool): Whether to disable moviepy logger or not
|
||||
|
||||
"""
|
||||
super().__init__(env)
|
||||
try:
|
||||
import moviepy # noqa: F401
|
||||
except ImportError as e:
|
||||
raise error.DependencyNotInstalled(
|
||||
"MoviePy is not installed, run `pip install moviepy`"
|
||||
) from e
|
||||
|
||||
if env.render_mode in {None, "human", "ansi"}:
|
||||
raise ValueError(
|
||||
f"Render mode is {env.render_mode}, which is incompatible with RecordVideo.",
|
||||
"Initialize your environment with a render_mode that returns an image, such as rgb_array.",
|
||||
)
|
||||
|
||||
if episode_trigger is None and step_trigger is None:
|
||||
|
||||
def capped_cubic_video_schedule(episode_id: int) -> bool:
|
||||
if episode_id < 1000:
|
||||
return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
|
||||
else:
|
||||
return episode_id % 1000 == 0
|
||||
|
||||
episode_trigger = capped_cubic_video_schedule
|
||||
|
||||
self.episode_trigger = episode_trigger
|
||||
self.step_trigger = step_trigger
|
||||
self.disable_logger = disable_logger
|
||||
|
||||
self.video_folder = os.path.abspath(video_folder)
|
||||
if os.path.isdir(self.video_folder):
|
||||
logger.warn(
|
||||
f"Overwriting existing videos at {self.video_folder} folder "
|
||||
f"(try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)"
|
||||
)
|
||||
os.makedirs(self.video_folder, exist_ok=True)
|
||||
|
||||
self.name_prefix = name_prefix
|
||||
self._video_name = None
|
||||
self.frames_per_sec = self.metadata.get("render_fps", 30)
|
||||
self.video_length = video_length if video_length != 0 else float("inf")
|
||||
self.recording = False
|
||||
self.recorded_frames = []
|
||||
self.render_history = []
|
||||
|
||||
self.step_id = -1
|
||||
self.episode_id = -1
|
||||
|
||||
def _capture_frame(self):
|
||||
assert self.recording, "Cannot capture a frame, recording wasn't started."
|
||||
|
||||
frame = self.env.render()
|
||||
if isinstance(frame, List):
|
||||
if len(frame) == 0: # render was called
|
||||
return
|
||||
self.render_history += frame
|
||||
frame = frame[-1]
|
||||
|
||||
if isinstance(frame, np.ndarray):
|
||||
self.recorded_frames.append(frame)
|
||||
else:
|
||||
self.stop_recording()
|
||||
logger.warn(
|
||||
"Recording stopped: expected type of frame returned by render ",
|
||||
f"to be a numpy array, got instead {type(frame)}.",
|
||||
)
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[WrapperObsType, dict[str, Any]]:
|
||||
"""Reset the environment and eventually starts a new recording."""
|
||||
obs, info = super().reset(seed=seed, options=options)
|
||||
self.episode_id += 1
|
||||
|
||||
if self.recording and self.video_length == float("inf"):
|
||||
self.stop_recording()
|
||||
|
||||
if self.episode_trigger and self.episode_trigger(self.episode_id):
|
||||
self.start_recording(f"{self.name_prefix}-episode-{self.episode_id}")
|
||||
if self.recording:
|
||||
self._capture_frame()
|
||||
if len(self.recorded_frames) > self.video_length:
|
||||
self.stop_recording()
|
||||
|
||||
return obs, info
|
||||
|
||||
def step(
|
||||
self, action: WrapperActType
|
||||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]:
|
||||
"""Steps through the environment using action, recording observations if :attr:`self.recording`."""
|
||||
obs, rew, terminated, truncated, info = self.env.step(action)
|
||||
self.step_id += 1
|
||||
|
||||
if self.step_trigger and self.step_trigger(self.step_id):
|
||||
self.start_recording(f"{self.name_prefix}-step-{self.step_id}")
|
||||
if self.recording:
|
||||
self._capture_frame()
|
||||
|
||||
if len(self.recorded_frames) > self.video_length:
|
||||
self.stop_recording()
|
||||
|
||||
return obs, rew, terminated, truncated, info
|
||||
|
||||
def start_recording(self, video_name):
|
||||
"""Start a new recording. If it is already recording, stops the current recording before starting the new one."""
|
||||
if self.recording:
|
||||
self.stop_recording()
|
||||
|
||||
self.recording = True
|
||||
self._video_name = video_name
|
||||
|
||||
def stop_recording(self):
|
||||
"""Stop current recording and saves the video."""
|
||||
assert self.recording, "stop_recording was called, but no recording was started"
|
||||
|
||||
if len(self.recorded_frames) == 0:
|
||||
logger.warn("Ignored saving a video as there were zero frames to save.")
|
||||
else:
|
||||
try:
|
||||
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
|
||||
except ImportError as e:
|
||||
raise error.DependencyNotInstalled(
|
||||
"MoviePy is not installed, run `pip install moviepy`"
|
||||
) from e
|
||||
|
||||
clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec)
|
||||
moviepy_logger = None if self.disable_logger else "bar"
|
||||
path = os.path.join(self.video_folder, f"{self._video_name}.mp4")
|
||||
clip.write_videofile(path, logger=moviepy_logger)
|
||||
|
||||
self.recorded_frames = []
|
||||
self.recording = False
|
||||
self._video_name = None
|
||||
|
||||
def render(self):
|
||||
"""Compute the render frames as specified by render_mode attribute during initialization of the environment."""
|
||||
render_out = super().render()
|
||||
if self.recording and isinstance(render_out, List):
|
||||
self.recorded_frames += render_out
|
||||
|
||||
if len(self.render_history) > 0:
|
||||
tmp_history = self.render_history
|
||||
self.render_history = []
|
||||
return tmp_history + render_out
|
||||
else:
|
||||
return render_out
|
||||
|
||||
def close(self):
|
||||
"""Closes the wrapper then the video recorder."""
|
||||
super().close()
|
||||
if self.recording:
|
||||
self.stop_recording()
|
||||
|
||||
def __del__(self):
|
||||
"""Warn the user in case last video wasn't saved."""
|
||||
if len(self.recorded_frames) > 0:
|
||||
logger.warn("Unable to save last video! Did you call close()?")
|
||||
|
||||
|
||||
class HumanRenderingV0(gym.Wrapper):
|
||||
|
@@ -10,7 +10,7 @@ from gymnasium.wrappers.monitoring import video_recorder
|
||||
def capped_cubic_video_schedule(episode_id: int) -> bool:
|
||||
"""The default episode trigger.
|
||||
|
||||
This function will trigger recordings at the episode indices 0, 1, 4, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ...
|
||||
This function will trigger recordings at the episode indices 0, 1, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ...
|
||||
|
||||
Args:
|
||||
episode_id: The episode number
|
||||
|
@@ -1 +1,166 @@
|
||||
"""Test suite for RecordVideoV0."""
|
||||
import os
|
||||
import shutil
|
||||
from typing import List
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.experimental.wrappers import RecordVideoV0
|
||||
|
||||
|
||||
def test_record_video_using_default_trigger():
|
||||
"""Test RecordVideo using the default episode trigger."""
|
||||
env = gym.make("CartPole-v1", render_mode="rgb_array_list")
|
||||
env = RecordVideoV0(env, "videos")
|
||||
env.reset()
|
||||
episode_count = 0
|
||||
for _ in range(199):
|
||||
action = env.action_space.sample()
|
||||
_, _, terminated, truncated, _ = env.step(action)
|
||||
if terminated or truncated:
|
||||
env.reset()
|
||||
episode_count += 1
|
||||
|
||||
env.close()
|
||||
assert os.path.isdir("videos")
|
||||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
|
||||
assert len(mp4_files) == sum(
|
||||
env.episode_trigger(i) for i in range(episode_count + 1)
|
||||
)
|
||||
shutil.rmtree("videos")
|
||||
|
||||
|
||||
def test_record_video_while_rendering():
|
||||
"""Test RecordVideo while calling render and using a _list render mode."""
|
||||
env = gym.make("FrozenLake-v1", render_mode="rgb_array_list")
|
||||
env = RecordVideoV0(env, "videos")
|
||||
env.reset()
|
||||
episode_count = 0
|
||||
for _ in range(199):
|
||||
action = env.action_space.sample()
|
||||
_, _, terminated, truncated, _ = env.step(action)
|
||||
env.render()
|
||||
if terminated or truncated:
|
||||
env.reset()
|
||||
episode_count += 1
|
||||
|
||||
env.close()
|
||||
assert os.path.isdir("videos")
|
||||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
|
||||
assert len(mp4_files) == sum(
|
||||
env.episode_trigger(i) for i in range(episode_count + 1)
|
||||
)
|
||||
shutil.rmtree("videos")
|
||||
|
||||
|
||||
def test_record_video_step_trigger():
|
||||
"""Test RecordVideo defining step trigger function."""
|
||||
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
|
||||
env._max_episode_steps = 20
|
||||
env = RecordVideoV0(env, "videos", step_trigger=lambda x: x % 100 == 0)
|
||||
env.reset()
|
||||
for _ in range(199):
|
||||
action = env.action_space.sample()
|
||||
_, _, terminated, truncated, _ = env.step(action)
|
||||
if terminated or truncated:
|
||||
env.reset()
|
||||
env.close()
|
||||
assert os.path.isdir("videos")
|
||||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
|
||||
shutil.rmtree("videos")
|
||||
assert len(mp4_files) == 2
|
||||
|
||||
|
||||
def test_record_video_both_trigger():
|
||||
"""Test RecordVideo defining both step and episode trigger functions."""
|
||||
env = gym.make(
|
||||
"CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True
|
||||
)
|
||||
env._max_episode_steps = 20
|
||||
env = RecordVideoV0(
|
||||
env,
|
||||
"videos",
|
||||
step_trigger=lambda x: x == 100,
|
||||
episode_trigger=lambda x: x == 0 or x == 3,
|
||||
)
|
||||
env.reset()
|
||||
for _ in range(199):
|
||||
action = env.action_space.sample()
|
||||
_, _, terminated, truncated, _ = env.step(action)
|
||||
if terminated or truncated:
|
||||
env.reset()
|
||||
env.close()
|
||||
assert os.path.isdir("videos")
|
||||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
|
||||
shutil.rmtree("videos")
|
||||
assert len(mp4_files) == 3
|
||||
|
||||
|
||||
def test_record_video_length():
|
||||
"""Test if argument video_length of RecordVideo works properly."""
|
||||
env = gym.make("CartPole-v1", render_mode="rgb_array_list")
|
||||
env._max_episode_steps = 20
|
||||
env = RecordVideoV0(env, "videos", step_trigger=lambda x: x == 0, video_length=10)
|
||||
env.reset()
|
||||
for _ in range(10):
|
||||
action = env.action_space.sample()
|
||||
env.step(action)
|
||||
|
||||
assert env.recording
|
||||
action = env.action_space.sample()
|
||||
env.step(action)
|
||||
assert not env.recording
|
||||
env.close()
|
||||
assert os.path.isdir("videos")
|
||||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
|
||||
assert len(mp4_files) == 1
|
||||
shutil.rmtree("videos")
|
||||
|
||||
|
||||
def test_rendering_works():
|
||||
"""Test if render output is as expected when the env is wrapped with RecordVideo."""
|
||||
env = gym.make("CartPole-v1", render_mode="rgb_array_list")
|
||||
env._max_episode_steps = 20
|
||||
env = RecordVideoV0(env, "videos")
|
||||
env.reset()
|
||||
n_steps = 10
|
||||
for _ in range(n_steps):
|
||||
action = env.action_space.sample()
|
||||
env.step(action)
|
||||
|
||||
render_out = env.render()
|
||||
assert isinstance(render_out, List)
|
||||
assert len(render_out) == n_steps + 1
|
||||
render_out = env.render()
|
||||
assert isinstance(render_out, List)
|
||||
assert len(render_out) == 0
|
||||
env.close()
|
||||
shutil.rmtree("videos")
|
||||
|
||||
|
||||
def make_env(gym_id, idx, **kwargs):
|
||||
"""Utility function to make an env and wrap it with RecordVideo only the first time."""
|
||||
|
||||
def thunk():
|
||||
env = gym.make(gym_id, disable_env_checker=True, **kwargs)
|
||||
env._max_episode_steps = 20
|
||||
if idx == 0:
|
||||
env = RecordVideoV0(env, "videos", step_trigger=lambda x: x % 100 == 0)
|
||||
return env
|
||||
|
||||
return thunk
|
||||
|
||||
|
||||
def test_record_video_within_vector():
|
||||
"""Test RecordVideo used as env of SyncVectorEnv."""
|
||||
envs = gym.vector.SyncVectorEnv(
|
||||
[make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(2)]
|
||||
)
|
||||
envs = gym.wrappers.RecordEpisodeStatistics(envs)
|
||||
envs.reset()
|
||||
for i in range(199):
|
||||
_, _, _, _, infos = envs.step(envs.action_space.sample())
|
||||
|
||||
assert os.path.isdir("videos")
|
||||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
|
||||
assert len(mp4_files) == 2
|
||||
shutil.rmtree("videos")
|
||||
|
Reference in New Issue
Block a user