Add stacktrace reporting to AsyncVectorEnv (#1119)

This commit is contained in:
Mark Towers
2024-07-15 15:53:11 +01:00
committed by GitHub
parent 020a7442c6
commit 992638e120
3 changed files with 99 additions and 31 deletions

View File

@@ -1,6 +1,7 @@
"""Test the `SyncVectorEnv` implementation."""
import re
import warnings
from multiprocessing import TimeoutError
import numpy as np
@@ -13,6 +14,7 @@ from gymnasium.error import (
)
from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple
from gymnasium.vector import AsyncVectorEnv
from tests.testing_env import GenericTestEnv
from tests.vector.testing_utils import (
CustomSpace,
make_custom_space_env,
@@ -345,3 +347,90 @@ def test_custom_space_async_vector_env_shared_memory():
with pytest.raises(ValueError):
env = AsyncVectorEnv(env_fns, shared_memory=True)
env.close(terminate=True)
def raise_error_reset(self, seed, options):
super(GenericTestEnv, self).reset(seed=seed, options=options)
if seed == 1:
raise ValueError("Error in reset")
return self.observation_space.sample(), {}
def raise_error_step(self, action):
if action >= 1:
raise ValueError(f"Error in step with {action}")
return self.observation_space.sample(), 0, False, False, {}
def test_async_vector_subenv_error():
envs = AsyncVectorEnv(
[
lambda: GenericTestEnv(
reset_func=raise_error_reset, step_func=raise_error_step
)
]
* 2
)
with warnings.catch_warnings(record=True) as caught_warnings:
envs.reset(seed=[0, 0])
assert len(caught_warnings) == 0
with warnings.catch_warnings(record=True) as caught_warnings:
with pytest.raises(ValueError, match="Error in reset"):
envs.reset(seed=[1, 0])
envs.close()
assert len(caught_warnings) == 3
assert (
"Received the following error from Worker-0 - Shutting it down"
in caught_warnings[0].message.args[0]
)
assert (
'in raise_error_reset\n raise ValueError("Error in reset")\nValueError: Error in reset'
in caught_warnings[1].message.args[0]
)
assert (
caught_warnings[2].message.args[0]
== "\x1b[31mERROR: Raising the last exception back to the main process.\x1b[0m"
)
envs = AsyncVectorEnv(
[
lambda: GenericTestEnv(
reset_func=raise_error_reset, step_func=raise_error_step
)
]
* 3
)
with warnings.catch_warnings(record=True) as caught_warnings:
with pytest.raises(ValueError, match="Error in step"):
envs.step([0, 1, 2])
envs.close()
assert len(caught_warnings) == 5
# due to variance in the step time, the order of warnings is random
assert re.match(
r"\x1b\[31mERROR: Received the following error from Worker-[12] - Shutting it down\x1b\[0m",
caught_warnings[0].message.args[0],
)
assert re.match(
r"\x1b\[31mERROR: Traceback \(most recent call last\):(?s:.)*in raise_error_step(?s:.)*ValueError: Error in step with [12]\n\x1b\[0m",
caught_warnings[1].message.args[0],
)
assert re.match(
r"\x1b\[31mERROR: Received the following error from Worker-[12] - Shutting it down\x1b\[0m",
caught_warnings[2].message.args[0],
)
assert re.match(
r"\x1b\[31mERROR: Traceback \(most recent call last\):(?s:.)*in raise_error_step(?s:.)*ValueError: Error in step with [12]\n\x1b\[0m",
caught_warnings[3].message.args[0],
)
assert (
caught_warnings[4].message.args[0]
== "\x1b[31mERROR: Raising the last exception back to the main process.\x1b[0m"
)