Files
Gymnasium/tests/utils/test_save_video.py
Mark Towers 27f8e85051 Merge v1.0.0 (#682)
Co-authored-by: Kallinteris Andreas <30759571+Kallinteris-Andreas@users.noreply.github.com>
Co-authored-by: Jet <38184875+jjshoots@users.noreply.github.com>
Co-authored-by: Omar Younis <42100908+younik@users.noreply.github.com>
2023-11-07 13:27:25 +00:00

121 lines
3.5 KiB
Python

import os
import shutil
import numpy as np
import gymnasium as gym
from gymnasium.utils.save_video import capped_cubic_video_schedule, save_video
def test_record_video_using_default_trigger():
    """Record CartPole episodes with ``save_video``'s default trigger.

    Every finished episode's rendered frames are handed to ``save_video``
    without an explicit ``episode_trigger`` or ``step_trigger``. The test
    expects the number of ``.mp4`` files written to equal the number of
    episode indices accepted by ``capped_cubic_video_schedule``.
    """
    # Remove leftovers from an earlier (possibly failed) run so stale videos
    # cannot inflate the file count checked below.
    shutil.rmtree("videos", ignore_errors=True)

    env = gym.make(
        "CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True
    )
    try:
        env.reset()
        step_starting_index = 0
        episode_index = 0
        for step_index in range(199):
            action = env.action_space.sample()
            _, _, terminated, truncated, _ = env.step(action)
            if terminated or truncated:
                save_video(
                    env.render(),
                    "videos",
                    fps=env.metadata["render_fps"],
                    step_starting_index=step_starting_index,
                    episode_index=episode_index,
                )
                # The next episode starts on the following global step.
                step_starting_index = step_index + 1
                episode_index += 1
                env.reset()

        assert os.path.isdir("videos")
        mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
    finally:
        # Release the env and the output directory even if save_video or an
        # assertion above fails, so later tests start from a clean state.
        env.close()
        shutil.rmtree("videos", ignore_errors=True)

    assert len(mp4_files) == sum(
        capped_cubic_video_schedule(i) for i in range(episode_index)
    )
def modulo_step_trigger(mod: int):
    """Create a step trigger that fires on every multiple of ``mod``.

    Args:
        mod: The trigger period; a step fires when ``step_index % mod == 0``
            (so step 0 always fires).

    Returns:
        A one-argument callable mapping a step index to a boolean.
    """
    return lambda step_index: step_index % mod == 0
def test_record_video_step_trigger():
    """Record CartPole videos with a step trigger that fires every 100 steps.

    Episodes are capped at 20 steps, so the episodes containing global steps
    0 and 100 both finish (by step 119 at the latest) within the 199-step
    loop and are passed to ``save_video``. The test expects exactly two
    ``.mp4`` files as a result.
    """
    # Remove leftovers from an earlier (possibly failed) run so stale videos
    # cannot inflate the file count checked below.
    shutil.rmtree("videos", ignore_errors=True)

    # Cap the episode length via ``gym.make`` instead of assigning the private
    # ``_max_episode_steps`` attribute on the returned env: that assignment
    # only has an effect if the outermost wrapper is the time-limit wrapper.
    env = gym.make(
        "CartPole-v1", render_mode="rgb_array_list", max_episode_steps=20
    )
    try:
        env.reset()
        step_starting_index = 0
        episode_index = 0
        for step_index in range(199):
            action = env.action_space.sample()
            _, _, terminated, truncated, _ = env.step(action)
            if terminated or truncated:
                save_video(
                    env.render(),
                    "videos",
                    fps=env.metadata["render_fps"],
                    step_trigger=modulo_step_trigger(100),
                    step_starting_index=step_starting_index,
                    episode_index=episode_index,
                )
                # The next episode starts on the following global step.
                step_starting_index = step_index + 1
                episode_index += 1
                env.reset()

        assert os.path.isdir("videos")
        mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
    finally:
        # Release the env and the output directory even on failure.
        env.close()
        shutil.rmtree("videos", ignore_errors=True)

    assert len(mp4_files) == 2
def test_record_video_within_vector():
    """Record videos from sub-env 0 of a two-env synchronous vector env.

    Frames are collected from sub-env 0 only, but the accumulated frames are
    flushed to ``save_video`` whenever *any* sub-env reports termination or
    truncation. The test lowers its expected video count by one whenever a
    flushed chunk ends exactly on a triggered step. This is a workaround
    tracked by the TODO in the loop body.
    """
    step_trigger = modulo_step_trigger(100)
    n_steps = 199
    expected_video = 2

    # Remove leftovers from an earlier (possibly failed) run so stale videos
    # cannot inflate the file count checked below.
    shutil.rmtree("videos", ignore_errors=True)

    envs = gym.make_vec(
        "CartPole-v1",
        num_envs=2,
        vectorization_mode="sync",
        render_mode="rgb_array_list",
    )
    try:
        envs.reset()
        episode_frames = []
        step_starting_index = 0
        episode_index = 0
        for step_index in range(n_steps):
            _, _, terminated, truncated, _ = envs.step(envs.action_space.sample())
            # Only the first sub-env's rendered frames are kept.
            episode_frames.extend(envs.call("render")[0])

            if np.any(np.logical_or(terminated, truncated)):
                save_video(
                    episode_frames,
                    "videos",
                    fps=envs.metadata["render_fps"],
                    step_trigger=step_trigger,
                    step_starting_index=step_starting_index,
                    episode_index=episode_index,
                )
                episode_frames = []
                step_starting_index = step_index + 1
                episode_index += 1

                # TODO: fix this test (see https://github.com/openai/gym/issues/3054)
                # Workaround: a chunk that ends exactly on a triggered step is
                # not expected to yield that step's video.
                if step_trigger(step_index):
                    expected_video -= 1

        assert os.path.isdir("videos")
        mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
    finally:
        # Release the vector env and the output directory even on failure.
        envs.close()
        shutil.rmtree("videos", ignore_errors=True)

    assert len(mp4_files) == expected_video