Add stacktrace reporting to AsyncVectorEnv (#1119)

This commit is contained in:
Mark Towers
2024-07-15 15:53:11 +01:00
committed by GitHub
parent 020a7442c6
commit 992638e120
3 changed files with 99 additions and 31 deletions

View File

@@ -1,17 +1,13 @@
"""Set of functions for logging messages.""" """Set of functions for logging messages."""
import sys
import warnings import warnings
from typing import Optional, Type from typing import Optional, Type
from gymnasium.utils import colorize from gymnasium.utils import colorize
DEBUG = 10
INFO = 20
WARN = 30 WARN = 30
ERROR = 40 ERROR = 40
DISABLED = 50
min_level = 30 min_level = 30
@@ -20,24 +16,6 @@ min_level = 30
warnings.filterwarnings("once", "", DeprecationWarning, module=r"^gymnasium\.") warnings.filterwarnings("once", "", DeprecationWarning, module=r"^gymnasium\.")
def set_level(level: int):
    """Replace the module-wide logging threshold with ``level``.

    Args:
        level: New value for the module global ``min_level``. The logging
            helpers in this module only emit a message when ``min_level`` is
            less than or equal to that message's own level constant.
    """
    # Rebind the module global directly; equivalent to ``global min_level``
    # followed by an assignment.
    globals()["min_level"] = level
def debug(msg: str, *args: object):
    """Write a ``DEBUG:``-prefixed message to ``sys.stderr``.

    Nothing is printed (and ``msg`` is not formatted) unless the module
    threshold ``min_level`` is less than or equal to ``DEBUG``.

    Args:
        msg: A ``%``-style format string.
        *args: Values interpolated into ``msg`` with the ``%`` operator.
    """
    if not (min_level <= DEBUG):
        return
    # Interpolate only after the level check so a suppressed message can
    # never raise a formatting error.
    rendered = msg % args
    print(f"DEBUG: {rendered}", file=sys.stderr)
def info(msg: str, *args: object):
    """Write an ``INFO:``-prefixed message to ``sys.stderr``.

    Nothing is printed (and ``msg`` is not formatted) unless the module
    threshold ``min_level`` is less than or equal to ``INFO``.

    Args:
        msg: A ``%``-style format string.
        *args: Values interpolated into ``msg`` with the ``%`` operator.
    """
    if not (min_level <= INFO):
        return
    # Interpolate only after the level check so a suppressed message can
    # never raise a formatting error.
    rendered = msg % args
    print(f"INFO: {rendered}", file=sys.stderr)
def warn( def warn(
msg: str, msg: str,
*args: object, *args: object,
@@ -68,8 +46,4 @@ def deprecation(msg: str, *args: object):
def error(msg: str, *args: object): def error(msg: str, *args: object):
"""Logs an error message if min_level <= ERROR in red on the sys.stderr.""" """Logs an error message if min_level <= ERROR in red on the sys.stderr."""
if min_level <= ERROR: if min_level <= ERROR:
print(colorize(f"ERROR: {msg % args}", "red"), file=sys.stderr) warnings.warn(colorize(f"ERROR: {msg % args}", "red"), stacklevel=3)
# DEPRECATED:
setLevel = set_level

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
import multiprocessing import multiprocessing
import sys import sys
import time import time
import traceback
from copy import deepcopy from copy import deepcopy
from enum import Enum from enum import Enum
from multiprocessing import Queue from multiprocessing import Queue
@@ -623,18 +624,19 @@ class AsyncVectorEnv(VectorEnv):
num_errors = self.num_envs - sum(successes) num_errors = self.num_envs - sum(successes)
assert num_errors > 0 assert num_errors > 0
for i in range(num_errors): for i in range(num_errors):
index, exctype, value = self.error_queue.get() index, exctype, value, trace = self.error_queue.get()
logger.error( logger.error(
f"Received the following error from Worker-{index}: {exctype.__name__}: {value}" f"Received the following error from Worker-{index} - Shutting it down"
) )
logger.error(f"Shutting down Worker-{index}.") logger.error(f"{trace}")
self.parent_pipes[index].close() self.parent_pipes[index].close()
self.parent_pipes[index] = None self.parent_pipes[index] = None
if i == num_errors - 1: if i == num_errors - 1:
logger.error("Raising the last exception back to the main process.") logger.error("Raising the last exception back to the main process.")
self._state = AsyncState.DEFAULT
raise exctype(value) raise exctype(value)
def __del__(self): def __del__(self):
@@ -723,7 +725,10 @@ def _async_worker(
f"Received unknown command `{command}`. Must be one of [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]." f"Received unknown command `{command}`. Must be one of [`reset`, `step`, `close`, `_call`, `_setattr`, `_check_spaces`]."
) )
except (KeyboardInterrupt, Exception): except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2]) error_type, error_message, _ = sys.exc_info()
trace = traceback.format_exc()
error_queue.put((index, error_type, error_message, trace))
pipe.send((None, False)) pipe.send((None, False))
finally: finally:
env.close() env.close()

View File

@@ -1,6 +1,7 @@
"""Test the `SyncVectorEnv` implementation.""" """Test the `SyncVectorEnv` implementation."""
import re import re
import warnings
from multiprocessing import TimeoutError from multiprocessing import TimeoutError
import numpy as np import numpy as np
@@ -13,6 +14,7 @@ from gymnasium.error import (
) )
from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple
from gymnasium.vector import AsyncVectorEnv from gymnasium.vector import AsyncVectorEnv
from tests.testing_env import GenericTestEnv
from tests.vector.testing_utils import ( from tests.vector.testing_utils import (
CustomSpace, CustomSpace,
make_custom_space_env, make_custom_space_env,
@@ -345,3 +347,90 @@ def test_custom_space_async_vector_env_shared_memory():
with pytest.raises(ValueError): with pytest.raises(ValueError):
env = AsyncVectorEnv(env_fns, shared_memory=True) env = AsyncVectorEnv(env_fns, shared_memory=True)
env.close(terminate=True) env.close(terminate=True)
def raise_error_reset(self, seed, options):
    """Reset function for ``GenericTestEnv`` that fails when ``seed`` is 1.

    It first delegates to the ``reset`` of ``GenericTestEnv``'s parent class
    with the same ``seed`` and ``options``.

    Args:
        self: The environment instance the function is bound to.
        seed: Seed forwarded to the parent ``reset``; the value 1 triggers
            the failure.
        options: Options forwarded to the parent ``reset``.

    Returns:
        A sample from ``self.observation_space`` and an empty info dict.

    Raises:
        ValueError: If ``seed == 1``.
    """
    super(GenericTestEnv, self).reset(seed=seed, options=options)

    triggers_failure = seed == 1
    if triggers_failure:
        raise ValueError("Error in reset")

    observation = self.observation_space.sample()
    return observation, {}
def raise_error_step(self, action):
    """Step function for ``GenericTestEnv`` that fails for actions >= 1.

    Args:
        self: The environment instance the function is bound to.
        action: The action to apply; any value ``>= 1`` triggers the failure.

    Returns:
        A 5-tuple of a sample from ``self.observation_space``, ``0``,
        ``False``, ``False`` and an empty dict.

    Raises:
        ValueError: If ``action >= 1``; the message includes the action.
    """
    triggers_failure = action >= 1
    if triggers_failure:
        raise ValueError(f"Error in step with {action}")

    observation = self.observation_space.sample()
    return observation, 0, False, False, {}
def test_async_vector_subenv_error():
    """Check how ``AsyncVectorEnv`` reports errors raised inside sub-environments.

    Two scenarios are exercised with environments built from
    ``raise_error_reset`` and ``raise_error_step``:

    * ``reset`` with ``seed=[1, 0]``: Worker-0 fails, so three warnings are
      expected (shutdown notice, traceback, final "raising" notice).
    * ``step`` with actions ``[0, 1, 2]``: Worker-1 and Worker-2 fail, so
      five warnings are expected (a notice/traceback pair per worker and the
      final "raising" notice).

    In both cases the ``ValueError`` must be re-raised to the caller.
    """

    def make_failing_env():
        return GenericTestEnv(
            reset_func=raise_error_reset, step_func=raise_error_step
        )

    # Final message emitted just before the exception is re-raised, including
    # the ANSI colour codes wrapped around it.
    raising_notice = (
        "\x1b[31mERROR: Raising the last exception back to the main process.\x1b[0m"
    )

    # --- Error raised during reset -------------------------------------
    envs = AsyncVectorEnv([make_failing_env] * 2)

    # A reset with non-failing seeds must not emit any warning.
    with warnings.catch_warnings(record=True) as caught_warnings:
        envs.reset(seed=[0, 0])
    assert len(caught_warnings) == 0

    with warnings.catch_warnings(record=True) as caught_warnings:
        with pytest.raises(ValueError, match="Error in reset"):
            envs.reset(seed=[1, 0])

    envs.close()

    assert len(caught_warnings) == 3
    shutdown_text, trace_text, final_text = (
        caught.message.args[0] for caught in caught_warnings
    )
    assert (
        "Received the following error from Worker-0 - Shutting it down"
        in shutdown_text
    )
    # Formatted tracebacks indent the offending source line by four spaces,
    # so the expected substring must carry that indentation to match.
    assert (
        'in raise_error_reset\n    raise ValueError("Error in reset")\nValueError: Error in reset'
        in trace_text
    )
    assert final_text == raising_notice

    # --- Errors raised during step -------------------------------------
    envs = AsyncVectorEnv([make_failing_env] * 3)

    with warnings.catch_warnings(record=True) as caught_warnings:
        with pytest.raises(ValueError, match="Error in step"):
            envs.step([0, 1, 2])

    envs.close()

    assert len(caught_warnings) == 5

    # due to variance in the step time, the order of warnings is random
    shutdown_pattern = r"\x1b\[31mERROR: Received the following error from Worker-[12] - Shutting it down\x1b\[0m"
    trace_pattern = r"\x1b\[31mERROR: Traceback \(most recent call last\):(?s:.)*in raise_error_step(?s:.)*ValueError: Error in step with [12]\n\x1b\[0m"
    messages = [caught.message.args[0] for caught in caught_warnings]
    # Each failing worker contributes a (shutdown notice, traceback) pair.
    for pair_start in (0, 2):
        assert re.match(shutdown_pattern, messages[pair_start])
        assert re.match(trace_pattern, messages[pair_start + 1])
    assert messages[4] == raising_notice