Added garbage collector on RecordVideo wrapper (#1378)

This commit is contained in:
vicbentu
2025-05-12 00:22:15 +02:00
committed by GitHub
parent 5dde9a79be
commit 271244dd49

View File

@@ -9,6 +9,7 @@
from __future__ import annotations
import gc
import os
from copy import deepcopy
from typing import Any, Callable, List, SupportsFloat
@@ -241,6 +242,7 @@ class RecordVideo(
name_prefix: str = "rl-video",
fps: int | None = None,
disable_logger: bool = True,
gc_trigger: Callable[[int], bool] | None = lambda episode: True,
):
"""Wrapper records videos of rollouts.
@@ -255,6 +257,7 @@ class RecordVideo(
fps (int): The frame per second in the video. Provides a custom video fps for environment, if ``None`` then
the environment metadata ``render_fps`` key is used if it exists, otherwise a default value of 30 is used.
disable_logger (bool): Whether to disable moviepy logger or not, default it is disabled
gc_trigger: Function that accepts an integer and returns ``True`` iff garbage collection should be performed after this episode
"""
gym.utils.RecordConstructorArgs.__init__(
self,
@@ -281,6 +284,7 @@ class RecordVideo(
self.episode_trigger = episode_trigger
self.step_trigger = step_trigger
self.disable_logger = disable_logger
self.gc_trigger = gc_trigger
self.video_folder = os.path.abspath(video_folder)
if os.path.isdir(self.video_folder):
@@ -414,6 +418,9 @@ class RecordVideo(
self.recording = False
self._video_name = None
if self.gc_trigger and self.gc_trigger(self.episode_id):
gc.collect()
def __del__(self):
"""Warn the user in case last video wasn't saved."""
if len(self.recorded_frames) > 0: