mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-17 20:39:12 +00:00
Add RecordVideo wrapper (#2300)
* Add RecordVideo wrapper
* bug fix
* don't change gym's core API
* add test cases
* reformat
This commit is contained in:
@@ -92,9 +92,10 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self.env_fns = env_fns
|
||||
self.shared_memory = shared_memory
|
||||
self.copy = copy
|
||||
dummy_env = env_fns[0]()
|
||||
self.metadata = dummy_env.metadata
|
||||
|
||||
if (observation_space is None) or (action_space is None):
|
||||
dummy_env = env_fns[0]()
|
||||
observation_space = observation_space or dummy_env.observation_space
|
||||
action_space = action_space or dummy_env.action_space
|
||||
dummy_env.close()
|
||||
|
@@ -33,6 +33,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
self.env_fns = env_fns
|
||||
self.envs = [env_fn() for env_fn in env_fns]
|
||||
self.copy = copy
|
||||
self.metadata = self.envs[0].metadata
|
||||
|
||||
if (observation_space is None) or (action_space is None):
|
||||
observation_space = observation_space or self.envs[0].observation_space
|
||||
|
@@ -32,6 +32,7 @@ class VectorEnv(gym.Env):
|
||||
def __init__(self, num_envs, observation_space, action_space):
|
||||
super(VectorEnv, self).__init__()
|
||||
self.num_envs = num_envs
|
||||
self.is_vector_env = True
|
||||
self.observation_space = batch_space(observation_space, n=num_envs)
|
||||
self.action_space = Tuple((action_space,) * num_envs)
|
||||
|
||||
|
@@ -14,3 +14,4 @@ from gym.wrappers.transform_reward import TransformReward
|
||||
from gym.wrappers.resize_observation import ResizeObservation
|
||||
from gym.wrappers.clip_action import ClipAction
|
||||
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
|
||||
from gym.wrappers.record_video import RecordVideo
|
||||
|
@@ -7,7 +7,6 @@ import gym
|
||||
class RecordEpisodeStatistics(gym.Wrapper):
|
||||
def __init__(self, env, deque_size=100):
|
||||
super(RecordEpisodeStatistics, self).__init__(env)
|
||||
self.env_is_vec = isinstance(env, gym.vector.VectorEnv)
|
||||
self.num_envs = getattr(env, "num_envs", 1)
|
||||
self.t0 = (
|
||||
time.time()
|
||||
@@ -17,6 +16,7 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
||||
self.episode_lengths = None
|
||||
self.return_queue = deque(maxlen=deque_size)
|
||||
self.length_queue = deque(maxlen=deque_size)
|
||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||
|
||||
def reset(self, **kwargs):
|
||||
observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
|
||||
@@ -30,7 +30,7 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
||||
)
|
||||
self.episode_returns += rewards
|
||||
self.episode_lengths += 1
|
||||
if not self.env_is_vec:
|
||||
if not self.is_vector_env:
|
||||
infos = [infos]
|
||||
dones = [dones]
|
||||
for i in range(len(dones)):
|
||||
@@ -52,6 +52,6 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
||||
return (
|
||||
observations,
|
||||
rewards,
|
||||
dones if self.env_is_vec else dones[0],
|
||||
infos if self.env_is_vec else infos[0],
|
||||
dones if self.is_vector_env else dones[0],
|
||||
infos if self.is_vector_env else infos[0],
|
||||
)
|
||||
|
79
gym/wrappers/record_video.py
Normal file
79
gym/wrappers/record_video.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import os
|
||||
import gym
|
||||
from typing import Callable
|
||||
|
||||
from gym.wrappers.monitoring import video_recorder
|
||||
|
||||
|
||||
class RecordVideo(gym.Wrapper):
    """Wrapper that records videos of the wrapped environment.

    A recording is started on every ``reset``. While no recording is in
    progress, a new one is also started after any step for which
    ``record_video_trigger(step_id)`` returns a truthy value.

    A recording is stopped when either:

    * ``video_length > 0`` and more than ``video_length`` frames have been
      captured, or
    * ``video_length == 0`` and the environment reports ``done`` (for vector
      environments only the first sub-environment's flag is inspected).

    Args:
        env: The environment to wrap.
        video_folder: Directory the videos are written to. It is created if
            it does not exist.
        record_video_trigger: Callable receiving the current step count and
            returning whether a recording should start at that step.
        video_length: Frame budget of a recording; ``0`` means "record until
            the episode is done".
        name_prefix: Prefix of the generated video file names.
    """

    def __init__(
        self,
        env,
        video_folder: str,
        record_video_trigger: Callable[[int], bool],
        video_length: int = 0,
        name_prefix: str = "rl-video",
    ):
        super(RecordVideo, self).__init__(env)
        self.record_video_trigger = record_video_trigger
        self.video_recorder = None

        self.video_folder = os.path.abspath(video_folder)
        # Create output folder if needed
        os.makedirs(self.video_folder, exist_ok=True)

        self.name_prefix = name_prefix
        self.step_id = 0
        self.video_length = video_length

        self.recording = False
        self.recorded_frames = 0
        self.is_vector_env = getattr(env, "is_vector_env", False)

    def reset(self, **kwargs):
        """Reset the wrapped environment and start a fresh recording.

        NOTE(review): the recording is started unconditionally here, i.e.
        ``record_video_trigger`` is not consulted on reset -- confirm that
        this is intended.
        """
        observations = super(RecordVideo, self).reset(**kwargs)
        self.start_video_recorder()
        return observations

    def start_video_recorder(self):
        """Close any active recording and begin a new one at the current step."""
        self.close_video_recorder()

        video_name = f"{self.name_prefix}-step-{self.step_id}"
        base_path = os.path.join(self.video_folder, video_name)
        self.video_recorder = video_recorder.VideoRecorder(
            env=self.env, base_path=base_path, metadata={"step_id": self.step_id}
        )

        # Capture the frame the environment is currently showing so the
        # video starts with it.
        self.video_recorder.capture_frame()
        self.recorded_frames = 1
        self.recording = True

    def _video_enabled(self):
        """Return whether the trigger asks for a recording at the current step."""
        return self.record_video_trigger(self.step_id)

    def step(self, action):
        """Step the wrapped environment and update the recording state.

        Returns the ``(observations, rewards, dones, infos)`` tuple of the
        wrapped environment unchanged.
        """
        observations, rewards, dones, infos = super(RecordVideo, self).step(action)

        self.step_id += 1
        if self.recording:
            self.video_recorder.capture_frame()
            self.recorded_frames += 1
            if self.video_length > 0:
                if self.recorded_frames > self.video_length:
                    self.close_video_recorder()
            else:
                # Inspect the done flag without rebinding ``dones``: the
                # previous implementation wrapped it in a list on this path
                # only and then indexed it on every path, which raised
                # TypeError for non-vector envs whenever this branch was
                # not taken.
                first_done = dones[0] if self.is_vector_env else dones
                if first_done:
                    self.close_video_recorder()
        elif self._video_enabled():
            self.start_video_recorder()

        return observations, rewards, dones, infos

    def close_video_recorder(self) -> None:
        """Finalize the active recording, if any."""
        if self.recording:
            self.video_recorder.close()
        self.recording = False
        # start_video_recorder re-initialises this counter before it is
        # read again.
        self.recorded_frames = 1

    def close(self):
        """Finalize any active recording, then close the wrapped environment."""
        # Without this, a recording still in progress when the environment
        # is closed would never be finalized.
        self.close_video_recorder()
        return super(RecordVideo, self).close()
|
50
gym/wrappers/test_record_video.py
Normal file
50
gym/wrappers/test_record_video.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import pytest
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import gym
|
||||
from gym.wrappers import RecordEpisodeStatistics, RecordVideo
|
||||
|
||||
|
||||
def test_record_video():
    """Run a short CartPole rollout through RecordVideo and check the output folder.

    The environment is always closed and the ``videos`` directory is always
    removed, even when the episode does not terminate within the step budget
    or the assertion fails.
    """
    env = gym.make("CartPole-v1")
    env = gym.wrappers.RecordVideo(
        env, "videos", record_video_trigger=lambda x: x % 100 == 0
    )
    try:
        env.reset()
        for _ in range(200):
            action = env.action_space.sample()
            _, _, done, _ = env.step(action)
            if done:
                break
    finally:
        # Previously the env was only closed when an episode happened to end
        # inside the loop.
        env.close()

    try:
        assert os.path.isdir("videos")
    finally:
        # ignore_errors so a missing directory does not mask the assertion.
        shutil.rmtree("videos", ignore_errors=True)
|
||||
|
||||
|
||||
def make_env(gym_id, seed):
    """Return a zero-argument factory for a seeded, video-recording env.

    Nothing is constructed until the returned callable is invoked. It then
    creates the environment ``gym_id``, seeds the environment and both of
    its spaces with ``seed``, and wraps it in ``RecordVideo`` writing to
    the ``videos`` folder.
    """

    def _every_hundredth_step(step_id):
        # Same trigger as before: fire on step counts divisible by 100.
        return step_id % 100 == 0

    def thunk():
        base_env = gym.make(gym_id)
        base_env.seed(seed)
        base_env.action_space.seed(seed)
        base_env.observation_space.seed(seed)
        return gym.wrappers.RecordVideo(
            base_env,
            "videos",
            record_video_trigger=_every_hundredth_step,
        )

    return thunk
|
||||
|
||||
|
||||
def test_record_video_vector():
    """Run RecordVideo-wrapped envs inside a SyncVectorEnv and check the output folder.

    The vector env is always closed and the ``videos`` directory is always
    removed, even if stepping or the assertion fails.
    """
    envs = gym.vector.SyncVectorEnv([make_env("CartPole-v1", 1 + i) for i in range(2)])
    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    try:
        envs.reset()
        for i in range(100):
            _, _, _, infos = envs.step(envs.action_space.sample())
            for info in infos:
                if "episode" in info:
                    # The step index was previously printed as a literal "i".
                    print(f"{i}, episode_reward={info['episode']['r']}")
                    break
    finally:
        envs.close()

    try:
        assert os.path.isdir("videos")
    finally:
        # ignore_errors so a missing directory does not mask the assertion.
        shutil.rmtree("videos", ignore_errors=True)
|
Reference in New Issue
Block a user