Add stacktrace reporting to AsyncVectorEnv (#1119)

This commit is contained in:
Mark Towers
2024-07-15 15:53:11 +01:00
committed by GitHub
parent 020a7442c6
commit 992638e120
3 changed files with 99 additions and 31 deletions

View File

@@ -1,17 +1,13 @@
"""Set of functions for logging messages."""
import sys
import warnings
from typing import Optional, Type
from gymnasium.utils import colorize
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40
DISABLED = 50
min_level = 30
@@ -20,24 +16,6 @@ min_level = 30
warnings.filterwarnings("once", "", DeprecationWarning, module=r"^gymnasium\.")
def set_level(level: int):
"""Set logging threshold on current logger."""
global min_level
min_level = level
def debug(msg: str, *args: object):
"""Logs a debug message to the user."""
if min_level <= DEBUG:
print(f"DEBUG: {msg % args}", file=sys.stderr)
def info(msg: str, *args: object):
"""Logs an info message to the user."""
if min_level <= INFO:
print(f"INFO: {msg % args}", file=sys.stderr)
def warn(
msg: str,
*args: object,
@@ -68,8 +46,4 @@ def deprecation(msg: str, *args: object):
def error(msg: str, *args: object):
"""Logs an error message if min_level <= ERROR in red on the sys.stderr."""
if min_level <= ERROR:
print(colorize(f"ERROR: {msg % args}", "red"), file=sys.stderr)
# DEPRECATED:
setLevel = set_level
warnings.warn(colorize(f"ERROR: {msg % args}", "red"), stacklevel=3)

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import multiprocessing
import sys
import time
import traceback
from copy import deepcopy
from enum import Enum
from multiprocessing import Queue
@@ -623,18 +624,19 @@ class AsyncVectorEnv(VectorEnv):
num_errors = self.num_envs - sum(successes)
assert num_errors > 0
for i in range(num_errors):
index, exctype, value = self.error_queue.get()
index, exctype, value, trace = self.error_queue.get()
logger.error(
f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
f"Received the following error from Worker-{index} - Shutting it down"
)
logger.error(f"Shutting down Worker-{index}.")
logger.error(f"{trace}")
self.parent_pipes[index].close()
self.parent_pipes[index] = None
if i == num_errors - 1:
logger.error("Raising the last exception back to the main process.")
self._state = AsyncState.DEFAULT
raise exctype(value)
def __del__(self):
@@ -723,7 +725,10 @@ def _async_worker(
f"Received unknown command `{command}`. Must be one of [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]."
)
except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2])
error_type, error_message, _ = sys.exc_info()
trace = traceback.format_exc()
error_queue.put((index, error_type, error_message, trace))
pipe.send((None, False))
finally:
env.close()

View File

@@ -1,6 +1,7 @@
"""Test the `SyncVectorEnv` implementation."""
import re
import warnings
from multiprocessing import TimeoutError
import numpy as np
@@ -13,6 +14,7 @@ from gymnasium.error import (
)
from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple
from gymnasium.vector import AsyncVectorEnv
from tests.testing_env import GenericTestEnv
from tests.vector.testing_utils import (
CustomSpace,
make_custom_space_env,
@@ -345,3 +347,90 @@ def test_custom_space_async_vector_env_shared_memory():
with pytest.raises(ValueError):
env = AsyncVectorEnv(env_fns, shared_memory=True)
env.close(terminate=True)
def raise_error_reset(self, seed, options):
super(GenericTestEnv, self).reset(seed=seed, options=options)
if seed == 1:
raise ValueError("Error in reset")
return self.observation_space.sample(), {}
def raise_error_step(self, action):
if action >= 1:
raise ValueError(f"Error in step with {action}")
return self.observation_space.sample(), 0, False, False, {}
def test_async_vector_subenv_error():
envs = AsyncVectorEnv(
[
lambda: GenericTestEnv(
reset_func=raise_error_reset, step_func=raise_error_step
)
]
* 2
)
with warnings.catch_warnings(record=True) as caught_warnings:
envs.reset(seed=[0, 0])
assert len(caught_warnings) == 0
with warnings.catch_warnings(record=True) as caught_warnings:
with pytest.raises(ValueError, match="Error in reset"):
envs.reset(seed=[1, 0])
envs.close()
assert len(caught_warnings) == 3
assert (
"Received the following error from Worker-0 - Shutting it down"
in caught_warnings[0].message.args[0]
)
assert (
'in raise_error_reset\n raise ValueError("Error in reset")\nValueError: Error in reset'
in caught_warnings[1].message.args[0]
)
assert (
caught_warnings[2].message.args[0]
== "\x1b[31mERROR: Raising the last exception back to the main process.\x1b[0m"
)
envs = AsyncVectorEnv(
[
lambda: GenericTestEnv(
reset_func=raise_error_reset, step_func=raise_error_step
)
]
* 3
)
with warnings.catch_warnings(record=True) as caught_warnings:
with pytest.raises(ValueError, match="Error in step"):
envs.step([0, 1, 2])
envs.close()
assert len(caught_warnings) == 5
# due to variance in the step time, the order of warnings is random
assert re.match(
r"\x1b\[31mERROR: Received the following error from Worker-[12] - Shutting it down\x1b\[0m",
caught_warnings[0].message.args[0],
)
assert re.match(
r"\x1b\[31mERROR: Traceback \(most recent call last\):(?s:.)*in raise_error_step(?s:.)*ValueError: Error in step with [12]\n\x1b\[0m",
caught_warnings[1].message.args[0],
)
assert re.match(
r"\x1b\[31mERROR: Received the following error from Worker-[12] - Shutting it down\x1b\[0m",
caught_warnings[2].message.args[0],
)
assert re.match(
r"\x1b\[31mERROR: Traceback \(most recent call last\):(?s:.)*in raise_error_step(?s:.)*ValueError: Error in step with [12]\n\x1b\[0m",
caught_warnings[3].message.args[0],
)
assert (
caught_warnings[4].message.args[0]
== "\x1b[31mERROR: Raising the last exception back to the main process.\x1b[0m"
)