import os
import shutil

import numpy as np

import gymnasium as gym
from gymnasium.utils.save_video import capped_cubic_video_schedule, save_video


def test_record_video_using_default_trigger():
    """Videos are saved using the default episode trigger (capped cubic schedule)."""
    env = gym.make(
        "CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True
    )

    env.reset()
    step_starting_index = 0
    episode_index = 0
    for step_index in range(199):
        action = env.action_space.sample()
        _, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            save_video(
                env.render(),
                "videos",
                fps=env.metadata["render_fps"],
                step_starting_index=step_starting_index,
                episode_index=episode_index,
            )
            step_starting_index = step_index + 1
            episode_index += 1
            env.reset()
    env.close()

    assert os.path.isdir("videos")
    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
    shutil.rmtree("videos")
    # One video is expected for every episode accepted by the default schedule.
    assert len(mp4_files) == sum(
        capped_cubic_video_schedule(i) for i in range(episode_index)
    )


def modulo_step_trigger(mod: int):
    """Return a step trigger that fires every `mod` steps."""

    def step_trigger(step_index):
        return step_index % mod == 0

    return step_trigger


def test_record_video_step_trigger():
    """Videos are saved using a custom step trigger."""
    env = gym.make("CartPole-v1", render_mode="rgb_array_list")
    env._max_episode_steps = 20

    env.reset()
    step_starting_index = 0
    episode_index = 0
    for step_index in range(199):
        action = env.action_space.sample()
        _, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            save_video(
                env.render(),
                "videos",
                fps=env.metadata["render_fps"],
                step_trigger=modulo_step_trigger(100),
                step_starting_index=step_starting_index,
                episode_index=episode_index,
            )
            step_starting_index = step_index + 1
            episode_index += 1
            env.reset()
    env.close()

    assert os.path.isdir("videos")
    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
    shutil.rmtree("videos")
    assert len(mp4_files) == 2


def test_record_video_within_vector():
    """Videos are saved when rendering frames from a (sync) vector environment."""
    step_trigger = modulo_step_trigger(100)
    n_steps = 199
    expected_video = 2

    envs = gym.make_vec(
        "CartPole-v1",
        num_envs=2,
        vectorization_mode="sync",
        render_mode="rgb_array_list",
    )
    envs.reset()
    episode_frames = []
    step_starting_index = 0
    episode_index = 0
    for step_index in range(n_steps):
        _, _, terminated, truncated, _ = envs.step(envs.action_space.sample())
        # Collect the rendered frames of the first sub-environment only.
        episode_frames.extend(envs.call("render")[0])

        if np.any(np.logical_or(terminated, truncated)):
            save_video(
                episode_frames,
                "videos",
                fps=envs.metadata["render_fps"],
                step_trigger=step_trigger,
                step_starting_index=step_starting_index,
                episode_index=episode_index,
            )
            episode_frames = []
            step_starting_index = step_index + 1
            episode_index += 1

            # TODO: fix this test (see https://github.com/openai/gym/issues/3054)
            if step_trigger(step_index):
                expected_video -= 1
    envs.close()

    assert os.path.isdir("videos")
    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
    shutil.rmtree("videos")
    assert len(mp4_files) == expected_video