mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 22:11:25 +00:00
186 lines
5.8 KiB
Python
186 lines
5.8 KiB
Python
"""Test suite for RecordVideo wrapper."""
|
|
|
|
import os
|
|
import shutil
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
import gymnasium as gym
|
|
from gymnasium.wrappers import RecordVideo, RenderCollection
|
|
|
|
|
|
def test_video_folder_and_filenames(
    video_folder="custom_video_folder", name_prefix="video-prefix"
):
    """Test that RecordVideo writes into `video_folder` using `name_prefix` for file names.

    Both an episode trigger (episodes 1 and 4) and a step trigger (steps 0 and 25)
    are installed; the test expects exactly one ``.mp4`` file per triggered
    episode and per triggered step.

    Args:
        video_folder: Folder handed to RecordVideo; removed again by the test.
        name_prefix: File name prefix handed to RecordVideo.
    """
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env = RecordVideo(
        env,
        video_folder=video_folder,
        name_prefix=name_prefix,
        episode_trigger=lambda x: x in [1, 4],
        step_trigger=lambda x: x in [0, 25],
    )

    # Seed both the environment and the action space so episode boundaries are reproducible.
    env.reset(seed=123)
    env.action_space.seed(123)
    for _ in range(100):
        action = env.action_space.sample()
        _, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            env.reset()
    env.close()

    assert os.path.isdir(video_folder)
    mp4_files = {file for file in os.listdir(video_folder) if file.endswith(".mp4")}
    # Remove the folder before the final assertion so a failure leaves nothing behind.
    shutil.rmtree(video_folder)
    # Build the expectation from `name_prefix` so a non-default prefix is checked as well.
    assert mp4_files == {
        f"{name_prefix}-step-0.mp4",  # step triggers
        f"{name_prefix}-step-25.mp4",
        f"{name_prefix}-episode-1.mp4",  # episode triggers
        f"{name_prefix}-episode-4.mp4",
    }
|
|
|
|
|
|
@pytest.mark.parametrize("episodic_trigger", [None, lambda x: x in [0, 3, 5, 10, 12]])
def test_episodic_trigger(episodic_trigger):
    """Test RecordVideo with the default episode trigger (``None``) and a custom one.

    The number of ``.mp4`` files written must equal the number of episode
    indices, out of those started during the run, for which the wrapper's
    ``episode_trigger`` returns a truthy value.
    """
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env = RecordVideo(env, "videos", episode_trigger=episodic_trigger)

    env.reset()
    episode_count = 0
    for _ in range(199):
        action = env.action_space.sample()
        _, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            env.reset()
            episode_count += 1
    env.close()

    assert os.path.isdir("videos")
    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
    # Clean up before the remaining assertions: the "videos" folder is shared with
    # other tests in this file, so leftovers from a failure would skew their counts.
    shutil.rmtree("videos")

    assert env.episode_trigger is not None
    # Episode indices 0..episode_count were started, hence the `+ 1`.
    assert len(mp4_files) == sum(
        env.episode_trigger(i) for i in range(episode_count + 1)
    )
|
|
|
|
|
|
def test_step_trigger():
    """Check RecordVideo when only a step trigger function is supplied."""
    base_env = gym.make("CartPole-v1", render_mode="rgb_array")
    env = RecordVideo(base_env, "videos", step_trigger=lambda step: step % 100 == 0)
    env.reset()

    for _ in range(199):
        step_result = env.step(env.action_space.sample())
        # Indices 2 and 3 of the step tuple are `terminated` and `truncated`.
        if step_result[2] or step_result[3]:
            env.reset()
    env.close()

    assert os.path.isdir("videos")
    recorded = [name for name in os.listdir("videos") if name.endswith(".mp4")]
    shutil.rmtree("videos")

    # The trigger fires at steps 0 and 100 within the 199 steps taken.
    assert len(recorded) == 2
|
|
|
|
|
|
def test_both_episodic_and_step_trigger():
    """Check RecordVideo when a step trigger and an episode trigger are both supplied."""
    base_env = gym.make("CartPole-v1", render_mode="rgb_array")
    env = RecordVideo(
        base_env,
        "videos",
        step_trigger=lambda step: step == 100,
        episode_trigger=lambda episode: episode in (0, 3),
    )
    # Episode reset time steps: 0, 18, 44, 55, 80, 103, 117, 143, 173, 191
    # Steps recorded: 0-18, 55-80, 100-103

    env.reset(seed=123)
    env.action_space.seed(123)
    for _ in range(199):
        step_result = env.step(env.action_space.sample())
        # Indices 2 and 3 of the step tuple are `terminated` and `truncated`.
        if step_result[2] or step_result[3]:
            env.reset()
    env.close()

    assert os.path.isdir("videos")
    recorded = [name for name in os.listdir("videos") if name.endswith(".mp4")]
    shutil.rmtree("videos")

    assert len(recorded) == 3
|
|
|
|
|
|
def test_video_length(video_length: int = 10):
    """Test if argument video_length of RecordVideo works properly.

    Recording is triggered at step 0 only. The wrapper must still be recording
    after the stepping loop, must stop after one further step, and must have
    written exactly one ``.mp4`` file.

    Args:
        video_length: Value passed to RecordVideo's ``video_length`` and the
            maximum number of steps taken before the first check.
    """
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    env = RecordVideo(
        env, "videos", step_trigger=lambda x: x == 0, video_length=video_length
    )

    try:
        env.reset(seed=123)
        env.action_space.seed(123)
        for _ in range(video_length):
            _, _, term, trunc, _ = env.step(env.action_space.sample())
            if term or trunc:
                break

        # check that the environment is still recording then take a step to take
        # the number of steps > video length
        assert env.recording
        env.step(env.action_space.sample())
        assert not env.recording
        env.close()

        # check that only one video is recorded
        assert os.path.isdir("videos")
        mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
        assert len(mp4_files) == 1
    finally:
        # Always remove the shared "videos" folder, even when an assertion fails,
        # so leftovers cannot affect other tests that count files in it.
        shutil.rmtree("videos", ignore_errors=True)
|
|
|
|
|
|
def test_with_rgb_array_list(n_steps: int = 10):
    """Test if `env.render` works with RenderCollection and RecordVideo.

    Both wrapping orders are exercised: RecordVideo around RenderCollection,
    then RenderCollection around RecordVideo. In each case the first
    ``render()`` call must return a list of ``n_steps + 1`` 3-dimensional
    arrays and the second call an empty list.

    Args:
        n_steps: Number of environment steps taken before rendering.
    """
    # fyi, can't work as a `pytest.mark.parametrize`
    wrapper_orders = (
        lambda base: RecordVideo(RenderCollection(base), "videos"),
        # Test in reverse order
        lambda base: RenderCollection(RecordVideo(base, "videos")),
    )

    for wrap in wrapper_orders:
        env = wrap(gym.make("CartPole-v1", render_mode="rgb_array"))
        try:
            env.reset(seed=123)
            env.action_space.seed(123)
            for _ in range(n_steps):
                env.step(env.action_space.sample())

            # First render call: one frame per step plus one extra.
            render_out = env.render()
            assert isinstance(render_out, list)
            assert len(render_out) == n_steps + 1
            assert all(isinstance(render, np.ndarray) for render in render_out)
            assert all(render.ndim == 3 for render in render_out)

            # Second render call: nothing has been collected since the first.
            render_out = env.render()
            assert isinstance(render_out, list)
            assert len(render_out) == 0

            env.close()
        finally:
            # Always remove the shared "videos" folder, even when an assertion
            # fails, so leftovers cannot affect the next order or other tests.
            shutil.rmtree("videos", ignore_errors=True)
|