mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 17:57:30 +00:00
Add stacktrace reporting to AsyncVectorEnv
(#1119)
This commit is contained in:
@@ -1,17 +1,13 @@
|
||||
"""Set of functions for logging messages."""
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Optional, Type
|
||||
|
||||
from gymnasium.utils import colorize
|
||||
|
||||
|
||||
DEBUG = 10
|
||||
INFO = 20
|
||||
WARN = 30
|
||||
ERROR = 40
|
||||
DISABLED = 50
|
||||
|
||||
min_level = 30
|
||||
|
||||
@@ -20,24 +16,6 @@ min_level = 30
|
||||
warnings.filterwarnings("once", "", DeprecationWarning, module=r"^gymnasium\.")
|
||||
|
||||
|
||||
def set_level(level: int):
|
||||
"""Set logging threshold on current logger."""
|
||||
global min_level
|
||||
min_level = level
|
||||
|
||||
|
||||
def debug(msg: str, *args: object):
|
||||
"""Logs a debug message to the user."""
|
||||
if min_level <= DEBUG:
|
||||
print(f"DEBUG: {msg % args}", file=sys.stderr)
|
||||
|
||||
|
||||
def info(msg: str, *args: object):
|
||||
"""Logs an info message to the user."""
|
||||
if min_level <= INFO:
|
||||
print(f"INFO: {msg % args}", file=sys.stderr)
|
||||
|
||||
|
||||
def warn(
|
||||
msg: str,
|
||||
*args: object,
|
||||
@@ -68,8 +46,4 @@ def deprecation(msg: str, *args: object):
|
||||
def error(msg: str, *args: object):
|
||||
"""Logs an error message if min_level <= ERROR in red on the sys.stderr."""
|
||||
if min_level <= ERROR:
|
||||
print(colorize(f"ERROR: {msg % args}", "red"), file=sys.stderr)
|
||||
|
||||
|
||||
# DEPRECATED:
|
||||
setLevel = set_level
|
||||
warnings.warn(colorize(f"ERROR: {msg % args}", "red"), stacklevel=3)
|
||||
|
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import multiprocessing
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from multiprocessing import Queue
|
||||
@@ -623,18 +624,19 @@ class AsyncVectorEnv(VectorEnv):
|
||||
num_errors = self.num_envs - sum(successes)
|
||||
assert num_errors > 0
|
||||
for i in range(num_errors):
|
||||
index, exctype, value = self.error_queue.get()
|
||||
index, exctype, value, trace = self.error_queue.get()
|
||||
|
||||
logger.error(
|
||||
f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
|
||||
f"Received the following error from Worker-{index} - Shutting it down"
|
||||
)
|
||||
logger.error(f"Shutting down Worker-{index}.")
|
||||
logger.error(f"{trace}")
|
||||
|
||||
self.parent_pipes[index].close()
|
||||
self.parent_pipes[index] = None
|
||||
|
||||
if i == num_errors - 1:
|
||||
logger.error("Raising the last exception back to the main process.")
|
||||
self._state = AsyncState.DEFAULT
|
||||
raise exctype(value)
|
||||
|
||||
def __del__(self):
|
||||
@@ -723,7 +725,10 @@ def _async_worker(
|
||||
f"Received unknown command `{command}`. Must be one of [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]."
|
||||
)
|
||||
except (KeyboardInterrupt, Exception):
|
||||
error_queue.put((index,) + sys.exc_info()[:2])
|
||||
error_type, error_message, _ = sys.exc_info()
|
||||
trace = traceback.format_exc()
|
||||
|
||||
error_queue.put((index, error_type, error_message, trace))
|
||||
pipe.send((None, False))
|
||||
finally:
|
||||
env.close()
|
||||
|
@@ -1,6 +1,7 @@
|
||||
"""Test the `SyncVectorEnv` implementation."""
|
||||
|
||||
import re
|
||||
import warnings
|
||||
from multiprocessing import TimeoutError
|
||||
|
||||
import numpy as np
|
||||
@@ -13,6 +14,7 @@ from gymnasium.error import (
|
||||
)
|
||||
from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple
|
||||
from gymnasium.vector import AsyncVectorEnv
|
||||
from tests.testing_env import GenericTestEnv
|
||||
from tests.vector.testing_utils import (
|
||||
CustomSpace,
|
||||
make_custom_space_env,
|
||||
@@ -345,3 +347,90 @@ def test_custom_space_async_vector_env_shared_memory():
|
||||
with pytest.raises(ValueError):
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=True)
|
||||
env.close(terminate=True)
|
||||
|
||||
|
||||
def raise_error_reset(self, seed, options):
|
||||
super(GenericTestEnv, self).reset(seed=seed, options=options)
|
||||
if seed == 1:
|
||||
raise ValueError("Error in reset")
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
|
||||
def raise_error_step(self, action):
|
||||
if action >= 1:
|
||||
raise ValueError(f"Error in step with {action}")
|
||||
|
||||
return self.observation_space.sample(), 0, False, False, {}
|
||||
|
||||
|
||||
def test_async_vector_subenv_error():
|
||||
envs = AsyncVectorEnv(
|
||||
[
|
||||
lambda: GenericTestEnv(
|
||||
reset_func=raise_error_reset, step_func=raise_error_step
|
||||
)
|
||||
]
|
||||
* 2
|
||||
)
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
envs.reset(seed=[0, 0])
|
||||
assert len(caught_warnings) == 0
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
with pytest.raises(ValueError, match="Error in reset"):
|
||||
envs.reset(seed=[1, 0])
|
||||
|
||||
envs.close()
|
||||
|
||||
assert len(caught_warnings) == 3
|
||||
assert (
|
||||
"Received the following error from Worker-0 - Shutting it down"
|
||||
in caught_warnings[0].message.args[0]
|
||||
)
|
||||
assert (
|
||||
'in raise_error_reset\n raise ValueError("Error in reset")\nValueError: Error in reset'
|
||||
in caught_warnings[1].message.args[0]
|
||||
)
|
||||
assert (
|
||||
caught_warnings[2].message.args[0]
|
||||
== "\x1b[31mERROR: Raising the last exception back to the main process.\x1b[0m"
|
||||
)
|
||||
|
||||
envs = AsyncVectorEnv(
|
||||
[
|
||||
lambda: GenericTestEnv(
|
||||
reset_func=raise_error_reset, step_func=raise_error_step
|
||||
)
|
||||
]
|
||||
* 3
|
||||
)
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
with pytest.raises(ValueError, match="Error in step"):
|
||||
envs.step([0, 1, 2])
|
||||
|
||||
envs.close()
|
||||
|
||||
assert len(caught_warnings) == 5
|
||||
# due to variance in the step time, the order of warnings is random
|
||||
assert re.match(
|
||||
r"\x1b\[31mERROR: Received the following error from Worker-[12] - Shutting it down\x1b\[0m",
|
||||
caught_warnings[0].message.args[0],
|
||||
)
|
||||
assert re.match(
|
||||
r"\x1b\[31mERROR: Traceback \(most recent call last\):(?s:.)*in raise_error_step(?s:.)*ValueError: Error in step with [12]\n\x1b\[0m",
|
||||
caught_warnings[1].message.args[0],
|
||||
)
|
||||
assert re.match(
|
||||
r"\x1b\[31mERROR: Received the following error from Worker-[12] - Shutting it down\x1b\[0m",
|
||||
caught_warnings[2].message.args[0],
|
||||
)
|
||||
assert re.match(
|
||||
r"\x1b\[31mERROR: Traceback \(most recent call last\):(?s:.)*in raise_error_step(?s:.)*ValueError: Error in step with [12]\n\x1b\[0m",
|
||||
caught_warnings[3].message.args[0],
|
||||
)
|
||||
assert (
|
||||
caught_warnings[4].message.args[0]
|
||||
== "\x1b[31mERROR: Raising the last exception back to the main process.\x1b[0m"
|
||||
)
|
||||
|
Reference in New Issue
Block a user