mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-18 04:49:12 +00:00
Add RecordVideo wrapper (#2300)
* Add RecordVideo wrapper * bug fix * don't change gym's core API * add test cases * reformat
This commit is contained in:
@@ -92,13 +92,14 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
self.env_fns = env_fns
|
self.env_fns = env_fns
|
||||||
self.shared_memory = shared_memory
|
self.shared_memory = shared_memory
|
||||||
self.copy = copy
|
self.copy = copy
|
||||||
|
dummy_env = env_fns[0]()
|
||||||
|
self.metadata = dummy_env.metadata
|
||||||
|
|
||||||
if (observation_space is None) or (action_space is None):
|
if (observation_space is None) or (action_space is None):
|
||||||
dummy_env = env_fns[0]()
|
|
||||||
observation_space = observation_space or dummy_env.observation_space
|
observation_space = observation_space or dummy_env.observation_space
|
||||||
action_space = action_space or dummy_env.action_space
|
action_space = action_space or dummy_env.action_space
|
||||||
dummy_env.close()
|
dummy_env.close()
|
||||||
del dummy_env
|
del dummy_env
|
||||||
super(AsyncVectorEnv, self).__init__(
|
super(AsyncVectorEnv, self).__init__(
|
||||||
num_envs=len(env_fns),
|
num_envs=len(env_fns),
|
||||||
observation_space=observation_space,
|
observation_space=observation_space,
|
||||||
|
@@ -33,6 +33,7 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
self.env_fns = env_fns
|
self.env_fns = env_fns
|
||||||
self.envs = [env_fn() for env_fn in env_fns]
|
self.envs = [env_fn() for env_fn in env_fns]
|
||||||
self.copy = copy
|
self.copy = copy
|
||||||
|
self.metadata = self.envs[0].metadata
|
||||||
|
|
||||||
if (observation_space is None) or (action_space is None):
|
if (observation_space is None) or (action_space is None):
|
||||||
observation_space = observation_space or self.envs[0].observation_space
|
observation_space = observation_space or self.envs[0].observation_space
|
||||||
|
@@ -32,6 +32,7 @@ class VectorEnv(gym.Env):
|
|||||||
def __init__(self, num_envs, observation_space, action_space):
|
def __init__(self, num_envs, observation_space, action_space):
|
||||||
super(VectorEnv, self).__init__()
|
super(VectorEnv, self).__init__()
|
||||||
self.num_envs = num_envs
|
self.num_envs = num_envs
|
||||||
|
self.is_vector_env = True
|
||||||
self.observation_space = batch_space(observation_space, n=num_envs)
|
self.observation_space = batch_space(observation_space, n=num_envs)
|
||||||
self.action_space = Tuple((action_space,) * num_envs)
|
self.action_space = Tuple((action_space,) * num_envs)
|
||||||
|
|
||||||
|
@@ -14,3 +14,4 @@ from gym.wrappers.transform_reward import TransformReward
|
|||||||
from gym.wrappers.resize_observation import ResizeObservation
|
from gym.wrappers.resize_observation import ResizeObservation
|
||||||
from gym.wrappers.clip_action import ClipAction
|
from gym.wrappers.clip_action import ClipAction
|
||||||
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
|
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
|
||||||
|
from gym.wrappers.record_video import RecordVideo
|
||||||
|
@@ -7,7 +7,6 @@ import gym
|
|||||||
class RecordEpisodeStatistics(gym.Wrapper):
|
class RecordEpisodeStatistics(gym.Wrapper):
|
||||||
def __init__(self, env, deque_size=100):
|
def __init__(self, env, deque_size=100):
|
||||||
super(RecordEpisodeStatistics, self).__init__(env)
|
super(RecordEpisodeStatistics, self).__init__(env)
|
||||||
self.env_is_vec = isinstance(env, gym.vector.VectorEnv)
|
|
||||||
self.num_envs = getattr(env, "num_envs", 1)
|
self.num_envs = getattr(env, "num_envs", 1)
|
||||||
self.t0 = (
|
self.t0 = (
|
||||||
time.time()
|
time.time()
|
||||||
@@ -17,6 +16,7 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
|||||||
self.episode_lengths = None
|
self.episode_lengths = None
|
||||||
self.return_queue = deque(maxlen=deque_size)
|
self.return_queue = deque(maxlen=deque_size)
|
||||||
self.length_queue = deque(maxlen=deque_size)
|
self.length_queue = deque(maxlen=deque_size)
|
||||||
|
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
|
observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
|
||||||
@@ -30,7 +30,7 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
|||||||
)
|
)
|
||||||
self.episode_returns += rewards
|
self.episode_returns += rewards
|
||||||
self.episode_lengths += 1
|
self.episode_lengths += 1
|
||||||
if not self.env_is_vec:
|
if not self.is_vector_env:
|
||||||
infos = [infos]
|
infos = [infos]
|
||||||
dones = [dones]
|
dones = [dones]
|
||||||
for i in range(len(dones)):
|
for i in range(len(dones)):
|
||||||
@@ -52,6 +52,6 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
|||||||
return (
|
return (
|
||||||
observations,
|
observations,
|
||||||
rewards,
|
rewards,
|
||||||
dones if self.env_is_vec else dones[0],
|
dones if self.is_vector_env else dones[0],
|
||||||
infos if self.env_is_vec else infos[0],
|
infos if self.is_vector_env else infos[0],
|
||||||
)
|
)
|
||||||
|
79
gym/wrappers/record_video.py
Normal file
79
gym/wrappers/record_video.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import os
|
||||||
|
import gym
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from gym.wrappers.monitoring import video_recorder
|
||||||
|
|
||||||
|
|
||||||
|
class RecordVideo(gym.Wrapper):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env,
|
||||||
|
video_folder: str,
|
||||||
|
record_video_trigger: Callable[[int], bool],
|
||||||
|
video_length: int = 0,
|
||||||
|
name_prefix: str = "rl-video",
|
||||||
|
):
|
||||||
|
super(RecordVideo, self).__init__(env)
|
||||||
|
self.record_video_trigger = record_video_trigger
|
||||||
|
self.video_recorder = None
|
||||||
|
|
||||||
|
self.video_folder = os.path.abspath(video_folder)
|
||||||
|
# Create output folder if needed
|
||||||
|
os.makedirs(self.video_folder, exist_ok=True)
|
||||||
|
|
||||||
|
self.name_prefix = name_prefix
|
||||||
|
self.step_id = 0
|
||||||
|
self.video_length = video_length
|
||||||
|
|
||||||
|
self.recording = False
|
||||||
|
self.recorded_frames = 0
|
||||||
|
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||||
|
|
||||||
|
def reset(self, **kwargs):
|
||||||
|
observations = super(RecordVideo, self).reset(**kwargs)
|
||||||
|
self.start_video_recorder()
|
||||||
|
return observations
|
||||||
|
|
||||||
|
def start_video_recorder(self):
|
||||||
|
self.close_video_recorder()
|
||||||
|
|
||||||
|
video_name = f"{self.name_prefix}-step-{self.step_id}"
|
||||||
|
base_path = os.path.join(self.video_folder, video_name)
|
||||||
|
self.video_recorder = video_recorder.VideoRecorder(
|
||||||
|
env=self.env, base_path=base_path, metadata={"step_id": self.step_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.video_recorder.capture_frame()
|
||||||
|
self.recorded_frames = 1
|
||||||
|
self.recording = True
|
||||||
|
|
||||||
|
def _video_enabled(self):
|
||||||
|
return self.record_video_trigger(self.step_id)
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
observations, rewards, dones, infos = super(RecordVideo, self).step(action)
|
||||||
|
|
||||||
|
self.step_id += 1
|
||||||
|
if self.recording:
|
||||||
|
self.video_recorder.capture_frame()
|
||||||
|
self.recorded_frames += 1
|
||||||
|
if self.video_length > 0:
|
||||||
|
if self.recorded_frames > self.video_length:
|
||||||
|
self.close_video_recorder()
|
||||||
|
else:
|
||||||
|
if not self.is_vector_env:
|
||||||
|
dones = [dones]
|
||||||
|
if dones[0]:
|
||||||
|
self.close_video_recorder()
|
||||||
|
|
||||||
|
elif self._video_enabled():
|
||||||
|
self.start_video_recorder()
|
||||||
|
|
||||||
|
return observations, rewards, dones if self.is_vector_env else dones[0], infos
|
||||||
|
|
||||||
|
def close_video_recorder(self) -> None:
|
||||||
|
if self.recording:
|
||||||
|
self.video_recorder.close()
|
||||||
|
self.recording = False
|
||||||
|
self.recorded_frames = 1
|
50
gym/wrappers/test_record_video.py
Normal file
50
gym/wrappers/test_record_video.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import gym
|
||||||
|
from gym.wrappers import RecordEpisodeStatistics, RecordVideo
|
||||||
|
|
||||||
|
|
||||||
|
def test_record_video():
|
||||||
|
env = gym.make("CartPole-v1")
|
||||||
|
env = gym.wrappers.RecordVideo(
|
||||||
|
env, "videos", record_video_trigger=lambda x: x % 100 == 0
|
||||||
|
)
|
||||||
|
env.reset()
|
||||||
|
for _ in range(200):
|
||||||
|
action = env.action_space.sample()
|
||||||
|
_, _, done, _ = env.step(action)
|
||||||
|
if done:
|
||||||
|
env.close()
|
||||||
|
break
|
||||||
|
assert os.path.isdir("videos")
|
||||||
|
shutil.rmtree("videos")
|
||||||
|
|
||||||
|
|
||||||
|
def make_env(gym_id, seed):
|
||||||
|
def thunk():
|
||||||
|
env = gym.make(gym_id)
|
||||||
|
env.seed(seed)
|
||||||
|
env.action_space.seed(seed)
|
||||||
|
env.observation_space.seed(seed)
|
||||||
|
env = gym.wrappers.RecordVideo(
|
||||||
|
env, "videos", record_video_trigger=lambda x: x % 100 == 0
|
||||||
|
)
|
||||||
|
return env
|
||||||
|
|
||||||
|
return thunk
|
||||||
|
|
||||||
|
|
||||||
|
def test_record_video_vector():
|
||||||
|
envs = gym.vector.SyncVectorEnv([make_env("CartPole-v1", 1 + i) for i in range(2)])
|
||||||
|
envs = gym.wrappers.RecordEpisodeStatistics(envs)
|
||||||
|
envs.reset()
|
||||||
|
for i in range(100):
|
||||||
|
_, _, _, infos = envs.step(envs.action_space.sample())
|
||||||
|
for info in infos:
|
||||||
|
if "episode" in info.keys():
|
||||||
|
print(f"i, episode_reward={info['episode']['r']}")
|
||||||
|
break
|
||||||
|
assert os.path.isdir("videos")
|
||||||
|
shutil.rmtree("videos")
|
Reference in New Issue
Block a user