mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-02 14:26:33 +00:00
Added garbage collector on RecordVideo wrapper (#1378)
This commit is contained in:
@@ -9,6 +9,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, List, SupportsFloat
|
||||
@@ -241,6 +242,7 @@ class RecordVideo(
|
||||
name_prefix: str = "rl-video",
|
||||
fps: int | None = None,
|
||||
disable_logger: bool = True,
|
||||
gc_trigger: Callable[[int], bool] | None = lambda episode: True,
|
||||
):
|
||||
"""Wrapper records videos of rollouts.
|
||||
|
||||
@@ -255,6 +257,7 @@ class RecordVideo(
|
||||
fps (int): The frame per second in the video. Provides a custom video fps for environment, if ``None`` then
|
||||
the environment metadata ``render_fps`` key is used if it exists, otherwise a default value of 30 is used.
|
||||
disable_logger (bool): Whether to disable moviepy logger or not, default it is disabled
|
||||
gc_trigger: Function that accepts an integer and returns ``True`` iff garbage collection should be performed after this episode
|
||||
"""
|
||||
gym.utils.RecordConstructorArgs.__init__(
|
||||
self,
|
||||
@@ -281,6 +284,7 @@ class RecordVideo(
|
||||
self.episode_trigger = episode_trigger
|
||||
self.step_trigger = step_trigger
|
||||
self.disable_logger = disable_logger
|
||||
self.gc_trigger = gc_trigger
|
||||
|
||||
self.video_folder = os.path.abspath(video_folder)
|
||||
if os.path.isdir(self.video_folder):
|
||||
@@ -414,6 +418,9 @@ class RecordVideo(
|
||||
self.recording = False
|
||||
self._video_name = None
|
||||
|
||||
if self.gc_trigger and self.gc_trigger(self.episode_id):
|
||||
gc.collect()
|
||||
|
||||
def __del__(self):
|
||||
"""Warn the user in case last video wasn't saved."""
|
||||
if len(self.recorded_frames) > 0:
|
||||
|
Reference in New Issue
Block a user