mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 14:10:30 +00:00
Add save_video util and deprecate RecordVideo in favor of it (#3016)
* init * add save_video util * simplify API @pseudo-rnd-thoughts * fix video_length and remove folder warning * remove RecordVideo deprecation warnings * add test record video * avoid test failing cascade
This commit is contained in:
109
gym/utils/save_video.py
Normal file
109
gym/utils/save_video.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Utility functions to save rendering videos."""
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
import gym
|
||||
from gym import logger
|
||||
|
||||
try:
|
||||
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
|
||||
except ImportError:
|
||||
raise gym.error.DependencyNotInstalled(
|
||||
"MoviePy is not installed, run `pip install moviepy`"
|
||||
)
|
||||
|
||||
|
||||
def capped_cubic_video_schedule(episode_id: int) -> bool:
|
||||
"""The default episode trigger.
|
||||
|
||||
This function will trigger recordings at the episode indices 0, 1, 4, 8, 27, ..., :math:`k^3`, ..., 729, 1000, 2000, 3000, ...
|
||||
|
||||
Args:
|
||||
episode_id: The episode number
|
||||
|
||||
Returns:
|
||||
If to apply a video schedule number
|
||||
"""
|
||||
if episode_id < 1000:
|
||||
return int(round(episode_id ** (1.0 / 3))) ** 3 == episode_id
|
||||
else:
|
||||
return episode_id % 1000 == 0
|
||||
|
||||
|
||||
def save_video(
|
||||
frames: list,
|
||||
video_folder: str,
|
||||
episode_trigger: Callable[[int], bool] = None,
|
||||
step_trigger: Callable[[int], bool] = None,
|
||||
video_length: Optional[int] = None,
|
||||
name_prefix: str = "rl-video",
|
||||
episode_index: int = 0,
|
||||
step_starting_index: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
"""Save videos from rendering frames.
|
||||
|
||||
This function extract video from a list of render frame episodes.
|
||||
|
||||
Args:
|
||||
frames (List[RenderFrame]): A list of frames to compose the video.
|
||||
video_folder (str): The folder where the recordings will be stored
|
||||
episode_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this episode
|
||||
step_trigger: Function that accepts an integer and returns ``True`` iff a recording should be started at this step
|
||||
video_length (int): The length of recorded episodes. If it isn't specified, the entire episode is recorded.
|
||||
Otherwise, snippets of the specified length are captured.
|
||||
name_prefix (str): Will be prepended to the filename of the recordings.
|
||||
episode_index (int): The index of the current episode.
|
||||
step_starting_index (int): The step index of the first frame.
|
||||
**kwargs: The kwargs that will be passed to moviepy's ImageSequenceClip.
|
||||
You need to specify either fps or duration.
|
||||
|
||||
Example:
|
||||
>>> import gym
|
||||
>>> from gym.utils.save_video import save_video
|
||||
>>> env = gym.make("FrozenLake-v1", render_mode="rgb_array")
|
||||
>>> env.reset()
|
||||
>>> step_starting_index = 0
|
||||
>>> episode_index = 0
|
||||
>>> for step_index in range(199):
|
||||
... action = env.action_space.sample()
|
||||
... _, _, done, _ = env.step(action)
|
||||
... if done:
|
||||
... save_video(
|
||||
... env.render(),
|
||||
... "videos",
|
||||
... fps=env.metadata["render_fps"],
|
||||
... step_starting_index=step_starting_index,
|
||||
... episode_index=episode_index
|
||||
... )
|
||||
... step_starting_index = step_index + 1
|
||||
... episode_index += 1
|
||||
... env.reset()
|
||||
>>> env.close()
|
||||
"""
|
||||
if not isinstance(frames, list):
|
||||
logger.error(
|
||||
f"Expected a list of frames, got a {frames.__class__.__name__} instead."
|
||||
)
|
||||
if episode_trigger is None and step_trigger is None:
|
||||
episode_trigger = capped_cubic_video_schedule
|
||||
|
||||
video_folder = os.path.abspath(video_folder)
|
||||
os.makedirs(video_folder, exist_ok=True)
|
||||
path_prefix = f"{video_folder}/{name_prefix}"
|
||||
|
||||
if episode_trigger is not None and episode_trigger(episode_index):
|
||||
clip = ImageSequenceClip(frames[:video_length], **kwargs)
|
||||
clip.write_videofile(f"{path_prefix}-episode-{episode_index}.mp4")
|
||||
|
||||
if step_trigger is not None:
|
||||
# skip the first frame since it comes from reset
|
||||
for step_index, frame_index in enumerate(
|
||||
range(1, len(frames)), start=step_starting_index
|
||||
):
|
||||
if step_trigger(step_index):
|
||||
end_index = (
|
||||
frame_index + video_length if video_length is not None else None
|
||||
)
|
||||
clip = ImageSequenceClip(frames[frame_index:end_index], **kwargs)
|
||||
clip.write_videofile(f"{path_prefix}-step-{step_index}.mp4")
|
104
tests/utils/test_save_video.py
Normal file
104
tests/utils/test_save_video.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import gym
|
||||
from gym.utils.save_video import capped_cubic_video_schedule, save_video
|
||||
|
||||
|
||||
def test_record_video_using_default_trigger():
|
||||
env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
|
||||
|
||||
env.reset()
|
||||
step_starting_index = 0
|
||||
episode_index = 0
|
||||
for step_index in range(199):
|
||||
action = env.action_space.sample()
|
||||
_, _, done, _ = env.step(action)
|
||||
if done:
|
||||
save_video(
|
||||
env.render(),
|
||||
"videos",
|
||||
fps=env.metadata["render_fps"],
|
||||
step_starting_index=step_starting_index,
|
||||
episode_index=episode_index,
|
||||
)
|
||||
step_starting_index = step_index + 1
|
||||
episode_index += 1
|
||||
env.reset()
|
||||
|
||||
env.close()
|
||||
assert os.path.isdir("videos")
|
||||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
|
||||
shutil.rmtree("videos")
|
||||
assert len(mp4_files) == sum(
|
||||
capped_cubic_video_schedule(i) for i in range(episode_index)
|
||||
)
|
||||
|
||||
|
||||
def modulo_step_trigger(mod: int):
|
||||
def step_trigger(step_index):
|
||||
return step_index % mod == 0
|
||||
|
||||
return step_trigger
|
||||
|
||||
|
||||
def test_record_video_step_trigger():
|
||||
env = gym.make("CartPole-v1", render_mode="rgb_array")
|
||||
env._max_episode_steps = 20
|
||||
|
||||
env.reset()
|
||||
step_starting_index = 0
|
||||
episode_index = 0
|
||||
for step_index in range(199):
|
||||
action = env.action_space.sample()
|
||||
_, _, done, _ = env.step(action)
|
||||
if done:
|
||||
save_video(
|
||||
env.render(),
|
||||
"videos",
|
||||
fps=env.metadata["render_fps"],
|
||||
step_trigger=modulo_step_trigger(100),
|
||||
step_starting_index=step_starting_index,
|
||||
episode_index=episode_index,
|
||||
)
|
||||
step_starting_index = step_index + 1
|
||||
episode_index += 1
|
||||
env.reset()
|
||||
env.close()
|
||||
|
||||
assert os.path.isdir("videos")
|
||||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
|
||||
shutil.rmtree("videos")
|
||||
assert len(mp4_files) == 2
|
||||
|
||||
|
||||
def test_record_video_within_vector():
|
||||
envs = gym.vector.make(
|
||||
"CartPole-v1", num_envs=2, asynchronous=True, render_mode="rgb_array"
|
||||
)
|
||||
envs = gym.wrappers.RecordEpisodeStatistics(envs)
|
||||
envs.reset()
|
||||
episode_frames = []
|
||||
step_starting_index = 0
|
||||
episode_index = 0
|
||||
for step_index in range(199):
|
||||
_, _, _, infos = envs.step(envs.action_space.sample())
|
||||
episode_frames.extend(envs.call("render")[0])
|
||||
|
||||
if "episode" in infos and infos["_episode"][0]:
|
||||
save_video(
|
||||
episode_frames,
|
||||
"videos",
|
||||
fps=envs.metadata["render_fps"],
|
||||
step_trigger=modulo_step_trigger(100),
|
||||
step_starting_index=step_starting_index,
|
||||
episode_index=episode_index,
|
||||
)
|
||||
episode_frames = []
|
||||
step_starting_index = step_index + 1
|
||||
episode_index += 1
|
||||
|
||||
assert os.path.isdir("videos")
|
||||
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
|
||||
shutil.rmtree("videos")
|
||||
assert len(mp4_files) == 2
|
Reference in New Issue
Block a user