mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-03 06:40:46 +00:00
Co-authored-by: Kallinteris Andreas <30759571+Kallinteris-Andreas@users.noreply.github.com> Co-authored-by: Jet <38184875+jjshoots@users.noreply.github.com> Co-authored-by: Omar Younis <42100908+younik@users.noreply.github.com>
121 lines
3.5 KiB
Python
121 lines
3.5 KiB
Python
import os
|
|
import shutil
|
|
|
|
import numpy as np
|
|
|
|
import gymnasium as gym
|
|
from gymnasium.utils.save_video import capped_cubic_video_schedule, save_video
|
|
|
|
|
|
def test_record_video_using_default_trigger():
    """Record CartPole episodes with ``save_video`` and no explicit trigger.

    Every finished episode's frames are handed to ``save_video`` without an
    ``episode_trigger`` or ``step_trigger``. The number of ``.mp4`` files
    written must equal the number of episode indices for which
    ``capped_cubic_video_schedule`` returns true.
    """
    video_folder = "videos"
    # Recordings left behind by an earlier aborted run would inflate the file
    # count checked at the end, so start from a clean slate.
    shutil.rmtree(video_folder, ignore_errors=True)

    env = gym.make(
        "CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True
    )
    try:
        env.reset()
        step_starting_index = 0
        episode_index = 0
        for step_index in range(199):
            action = env.action_space.sample()
            _, _, terminated, truncated, _ = env.step(action)
            if terminated or truncated:
                save_video(
                    env.render(),
                    video_folder,
                    fps=env.metadata["render_fps"],
                    step_starting_index=step_starting_index,
                    episode_index=episode_index,
                )
                # The next episode begins on the step after the one that
                # ended this episode.
                step_starting_index = step_index + 1
                episode_index += 1
                env.reset()

        assert os.path.isdir(video_folder)
        mp4_files = [
            file for file in os.listdir(video_folder) if file.endswith(".mp4")
        ]
    finally:
        # Always clean up, even when the loop or an assertion above fails, so
        # a failure here cannot leak files into later tests.
        shutil.rmtree(video_folder, ignore_errors=True)
        env.close()

    assert len(mp4_files) == sum(
        capped_cubic_video_schedule(i) for i in range(episode_index)
    )
|
|
|
|
|
|
def modulo_step_trigger(mod: int):
    """Return a predicate that is true for step indices divisible by ``mod``.

    Args:
        mod: The divisor each step index is checked against.

    Returns:
        A one-argument callable taking a step index and returning whether
        that index leaves no remainder when divided by ``mod``.
    """

    def fires_on(step_index):
        remainder = step_index % mod
        return remainder == 0

    return fires_on
|
|
|
|
|
|
def test_record_video_step_trigger():
    """Record CartPole episodes with a step trigger that fires every 100 steps.

    The loop runs 199 steps, so the trigger is true for step indices 0 and
    100 only; exactly two ``.mp4`` files are expected.
    """
    video_folder = "videos"
    # Recordings left behind by an earlier aborted run would inflate the file
    # count checked at the end, so start from a clean slate.
    shutil.rmtree(video_folder, ignore_errors=True)

    # The episode cap is passed to ``gym.make`` so that it is applied when the
    # environment is built. Assigning ``_max_episode_steps`` on the returned
    # object afterwards only sets an attribute on the outermost wrapper and is
    # not guaranteed to reach the time-limit logic.
    env = gym.make(
        "CartPole-v1",
        render_mode="rgb_array_list",
        max_episode_steps=20,
    )
    try:
        env.reset()
        step_starting_index = 0
        episode_index = 0
        for step_index in range(199):
            action = env.action_space.sample()
            _, _, terminated, truncated, _ = env.step(action)
            if terminated or truncated:
                save_video(
                    env.render(),
                    video_folder,
                    fps=env.metadata["render_fps"],
                    step_trigger=modulo_step_trigger(100),
                    step_starting_index=step_starting_index,
                    episode_index=episode_index,
                )
                # The next episode begins on the step after the one that
                # ended this episode.
                step_starting_index = step_index + 1
                episode_index += 1
                env.reset()

        assert os.path.isdir(video_folder)
        mp4_files = [
            file for file in os.listdir(video_folder) if file.endswith(".mp4")
        ]
    finally:
        # Always clean up, even when the loop or an assertion above fails, so
        # a failure here cannot leak files into later tests.
        shutil.rmtree(video_folder, ignore_errors=True)
        env.close()

    assert len(mp4_files) == 2
|
|
|
|
|
|
def test_record_video_within_vector():
    """Record frames from a two-environment sync vector env via ``save_video``.

    Frames rendered by the first sub-environment are accumulated and written
    out whenever any sub-environment reports termination or truncation. The
    step trigger fires every 100 steps, so two videos are expected over 199
    steps, minus one for each episode end that lands on a trigger step (see
    the TODO below).
    """
    video_folder = "videos"
    # Recordings left behind by an earlier aborted run would inflate the file
    # count checked at the end, so start from a clean slate.
    shutil.rmtree(video_folder, ignore_errors=True)

    step_trigger = modulo_step_trigger(100)
    n_steps = 199
    expected_video = 2

    envs = gym.make_vec(
        "CartPole-v1",
        num_envs=2,
        vectorization_mode="sync",
        render_mode="rgb_array_list",
    )
    try:
        envs.reset()
        episode_frames = []
        step_starting_index = 0
        episode_index = 0
        for step_index in range(n_steps):
            _, _, terminated, truncated, _ = envs.step(envs.action_space.sample())
            # Only the first sub-environment's rendered frames are recorded.
            episode_frames.extend(envs.call("render")[0])

            if np.any(np.logical_or(terminated, truncated)):
                save_video(
                    episode_frames,
                    video_folder,
                    fps=envs.metadata["render_fps"],
                    step_trigger=step_trigger,
                    step_starting_index=step_starting_index,
                    episode_index=episode_index,
                )
                episode_frames = []
                step_starting_index = step_index + 1
                episode_index += 1

                # TODO: fix this test (see https://github.com/openai/gym/issues/3054)
                if step_trigger(step_index):
                    expected_video -= 1

        assert os.path.isdir(video_folder)
        mp4_files = [
            file for file in os.listdir(video_folder) if file.endswith(".mp4")
        ]
    finally:
        # Always clean up, even when the loop or an assertion above fails, so
        # a failure here cannot leak files into later tests.
        shutil.rmtree(video_folder, ignore_errors=True)
        envs.close()

    assert len(mp4_files) == expected_video
|