Files
Gymnasium/tests/wrappers/test_record_video.py
2025-06-07 17:57:58 +01:00

186 lines
5.8 KiB
Python

"""Test suite for RecordVideo wrapper."""
import os
import shutil
import numpy as np
import pytest
import gymnasium as gym
from gymnasium.wrappers import RecordVideo, RenderCollection
def test_video_folder_and_filenames(
    video_folder="custom_video_folder", name_prefix="video-prefix"
):
    """Test that RecordVideo saves into ``video_folder`` using ``name_prefix``.

    A step trigger (steps 0 and 25) and an episode trigger (episodes 1 and 4) are
    both set, so the folder must contain exactly one video per trigger hit, named
    ``{name_prefix}-step-{step}.mp4`` or ``{name_prefix}-episode-{episode}.mp4``.

    Args:
        video_folder: Folder the wrapper writes videos into; removed afterwards.
        name_prefix: Prefix expected on every saved video file name.
    """
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env = RecordVideo(
        env,
        video_folder=video_folder,
        name_prefix=name_prefix,
        episode_trigger=lambda x: x in [1, 4],
        step_trigger=lambda x: x in [0, 25],
    )
    try:
        env.reset(seed=123)
        env.action_space.seed(123)
        for _ in range(100):
            action = env.action_space.sample()
            _, _, terminated, truncated, _ = env.step(action)
            if terminated or truncated:
                env.reset()
        env.close()

        assert os.path.isdir(video_folder)
        mp4_files = {
            file for file in os.listdir(video_folder) if file.endswith(".mp4")
        }
    finally:
        # Always remove the folder, even when stepping or an assertion fails, so
        # stale videos cannot leak into other tests.
        shutil.rmtree(video_folder, ignore_errors=True)

    # Derive the expected names from ``name_prefix`` so that non-default
    # arguments are checked correctly instead of against a hard-coded prefix.
    assert mp4_files == {
        f"{name_prefix}-step-0.mp4",  # step triggers
        f"{name_prefix}-step-25.mp4",
        f"{name_prefix}-episode-1.mp4",  # episode triggers
        f"{name_prefix}-episode-4.mp4",
    }
@pytest.mark.parametrize("episodic_trigger", [None, lambda x: x in [0, 3, 5, 10, 12]])
def test_episodic_trigger(episodic_trigger):
    """Test RecordVideo with the default (``None``) and a custom episode trigger.

    The number of saved videos must equal the number of episode indices, from 0 up
    to and including the last (unfinished) episode, for which the trigger fires.

    Args:
        episodic_trigger: ``None`` to use the wrapper's default trigger, or a
            callable mapping an episode index to whether it is recorded.
    """
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env = RecordVideo(env, "videos", episode_trigger=episodic_trigger)
    try:
        env.reset()
        episode_count = 0
        for _ in range(199):
            action = env.action_space.sample()
            _, _, terminated, truncated, _ = env.step(action)
            if terminated or truncated:
                env.reset()
                episode_count += 1
        env.close()

        assert os.path.isdir("videos")
        mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
        # With ``episode_trigger=None`` the wrapper is expected to install its own
        # default trigger, exposed as ``env.episode_trigger``.
        assert env.episode_trigger is not None
        assert len(mp4_files) == sum(
            env.episode_trigger(i) for i in range(episode_count + 1)
        )
    finally:
        # Clean up even on failure: the other tests count ``.mp4`` files in the
        # same "videos" folder, so leftovers would make them fail too.
        shutil.rmtree("videos", ignore_errors=True)
def test_step_trigger():
    """Test RecordVideo defining a step trigger that fires every 100 steps.

    Over 199 steps (step ids 0-198) the trigger fires at steps 0 and 100, so
    exactly two videos are expected.
    """
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env = RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
    try:
        # Seed like the sibling tests so that failures are reproducible.
        env.reset(seed=123)
        env.action_space.seed(123)
        for _ in range(199):
            action = env.action_space.sample()
            _, _, terminated, truncated, _ = env.step(action)
            if terminated or truncated:
                env.reset()
        env.close()

        assert os.path.isdir("videos")
        mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
    finally:
        # Always remove the shared "videos" folder so other tests start clean.
        shutil.rmtree("videos", ignore_errors=True)

    assert len(mp4_files) == 2
def test_both_episodic_and_step_trigger():
    """Test RecordVideo defining both step and episode trigger functions.

    Episodes 0 and 3 are recorded by the episode trigger and step 100 by the step
    trigger, so three videos are expected in total.
    """
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env = RecordVideo(
        env,
        "videos",
        step_trigger=lambda x: x == 100,
        episode_trigger=lambda x: x == 0 or x == 3,
    )
    # episode reset time steps: 0, 18, 44, 55, 80, 103, 117, 143, 173, 191
    # steps recorded: 0-18 (episode 0), 55-80 (episode 3),
    # 100-103 (step 100 until the next reset)
    try:
        env.reset(seed=123)
        env.action_space.seed(123)
        for _ in range(199):
            action = env.action_space.sample()
            _, _, terminated, truncated, _ = env.step(action)
            if terminated or truncated:
                env.reset()
        env.close()

        assert os.path.isdir("videos")
        mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
    finally:
        # Always remove the shared "videos" folder so other tests start clean.
        shutil.rmtree("videos", ignore_errors=True)

    assert len(mp4_files) == 3
def test_video_length(video_length: int = 10):
    """Test if argument video_length of RecordVideo works properly.

    Recording starts at step 0. After ``video_length`` steps the wrapper must still
    be recording, one more step must stop it, and exactly one video is saved.

    Args:
        video_length: Maximum video length passed to ``RecordVideo``.
    """
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env = RecordVideo(
        env, "videos", step_trigger=lambda x: x == 0, video_length=video_length
    )
    try:
        env.reset(seed=123)
        env.action_space.seed(123)
        # NOTE(review): this assumes the seeded episode lasts longer than
        # ``video_length`` steps; if it ends early the loop breaks and the
        # recording checks below would not hold. Confirm for non-default lengths.
        for _ in range(video_length):
            _, _, term, trunc, _ = env.step(env.action_space.sample())
            if term or trunc:
                break

        # Check that the environment is still recording, then take one more step
        # so that the number of recorded steps exceeds the video length.
        assert env.recording
        env.step(env.action_space.sample())
        assert not env.recording
        env.close()

        # check that only one video is recorded
        assert os.path.isdir("videos")
        mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
        assert len(mp4_files) == 1
    finally:
        # Clean up even on failure so stale videos don't break other tests.
        shutil.rmtree("videos", ignore_errors=True)
def _check_rgb_array_list(env, n_steps: int):
    """Step ``env`` ``n_steps`` times and check its collected ``render`` output.

    ``env`` combines ``RecordVideo`` (recording into "videos") and
    ``RenderCollection`` in either wrapping order. The env is closed, and the
    "videos" folder is removed afterwards even if a check fails.
    """
    try:
        env.reset(seed=123)
        env.action_space.seed(123)
        for _ in range(n_steps):
            env.step(env.action_space.sample())

        # Expect one frame for the reset plus one per step, each an image array.
        render_out = env.render()
        assert isinstance(render_out, list)
        assert len(render_out) == n_steps + 1
        assert all(isinstance(render, np.ndarray) for render in render_out)
        assert all(render.ndim == 3 for render in render_out)

        # Rendering again without stepping is expected to yield no frames.
        render_out = env.render()
        assert isinstance(render_out, list)
        assert len(render_out) == 0

        env.close()
    finally:
        shutil.rmtree("videos", ignore_errors=True)


def test_with_rgb_array_list(n_steps: int = 10):
    """Test if `env.render` works with RenderCollection and RecordVideo."""
    # fyi, can't work as a `pytest.mark.parametrize`
    # Each env is built only when its check runs, so the first "videos" folder
    # is already removed before the second RecordVideo is created.
    _check_rgb_array_list(
        RecordVideo(
            RenderCollection(gym.make("CartPole-v1", render_mode="rgb_array")),
            "videos",
        ),
        n_steps,
    )
    # Test in reverse order
    _check_rgb_array_list(
        RenderCollection(
            RecordVideo(gym.make("CartPole-v1", render_mode="rgb_array"), "videos")
        ),
        n_steps,
    )