mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 09:55:39 +00:00
Add stacktrace reporting to AsyncVectorEnv
(#1119)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""Test the `SyncVectorEnv` implementation."""
|
||||
|
||||
import re
|
||||
import warnings
|
||||
from multiprocessing import TimeoutError
|
||||
|
||||
import numpy as np
|
||||
@@ -13,6 +14,7 @@ from gymnasium.error import (
|
||||
)
|
||||
from gymnasium.spaces import Box, Discrete, MultiDiscrete, Tuple
|
||||
from gymnasium.vector import AsyncVectorEnv
|
||||
from tests.testing_env import GenericTestEnv
|
||||
from tests.vector.testing_utils import (
|
||||
CustomSpace,
|
||||
make_custom_space_env,
|
||||
@@ -345,3 +347,90 @@ def test_custom_space_async_vector_env_shared_memory():
|
||||
with pytest.raises(ValueError):
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=True)
|
||||
env.close(terminate=True)
|
||||
|
||||
|
||||
def raise_error_reset(self, seed, options):
|
||||
super(GenericTestEnv, self).reset(seed=seed, options=options)
|
||||
if seed == 1:
|
||||
raise ValueError("Error in reset")
|
||||
return self.observation_space.sample(), {}
|
||||
|
||||
|
||||
def raise_error_step(self, action):
|
||||
if action >= 1:
|
||||
raise ValueError(f"Error in step with {action}")
|
||||
|
||||
return self.observation_space.sample(), 0, False, False, {}
|
||||
|
||||
|
||||
def test_async_vector_subenv_error():
|
||||
envs = AsyncVectorEnv(
|
||||
[
|
||||
lambda: GenericTestEnv(
|
||||
reset_func=raise_error_reset, step_func=raise_error_step
|
||||
)
|
||||
]
|
||||
* 2
|
||||
)
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
envs.reset(seed=[0, 0])
|
||||
assert len(caught_warnings) == 0
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
with pytest.raises(ValueError, match="Error in reset"):
|
||||
envs.reset(seed=[1, 0])
|
||||
|
||||
envs.close()
|
||||
|
||||
assert len(caught_warnings) == 3
|
||||
assert (
|
||||
"Received the following error from Worker-0 - Shutting it down"
|
||||
in caught_warnings[0].message.args[0]
|
||||
)
|
||||
assert (
|
||||
'in raise_error_reset\n raise ValueError("Error in reset")\nValueError: Error in reset'
|
||||
in caught_warnings[1].message.args[0]
|
||||
)
|
||||
assert (
|
||||
caught_warnings[2].message.args[0]
|
||||
== "\x1b[31mERROR: Raising the last exception back to the main process.\x1b[0m"
|
||||
)
|
||||
|
||||
envs = AsyncVectorEnv(
|
||||
[
|
||||
lambda: GenericTestEnv(
|
||||
reset_func=raise_error_reset, step_func=raise_error_step
|
||||
)
|
||||
]
|
||||
* 3
|
||||
)
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
with pytest.raises(ValueError, match="Error in step"):
|
||||
envs.step([0, 1, 2])
|
||||
|
||||
envs.close()
|
||||
|
||||
assert len(caught_warnings) == 5
|
||||
# due to variance in the step time, the order of warnings is random
|
||||
assert re.match(
|
||||
r"\x1b\[31mERROR: Received the following error from Worker-[12] - Shutting it down\x1b\[0m",
|
||||
caught_warnings[0].message.args[0],
|
||||
)
|
||||
assert re.match(
|
||||
r"\x1b\[31mERROR: Traceback \(most recent call last\):(?s:.)*in raise_error_step(?s:.)*ValueError: Error in step with [12]\n\x1b\[0m",
|
||||
caught_warnings[1].message.args[0],
|
||||
)
|
||||
assert re.match(
|
||||
r"\x1b\[31mERROR: Received the following error from Worker-[12] - Shutting it down\x1b\[0m",
|
||||
caught_warnings[2].message.args[0],
|
||||
)
|
||||
assert re.match(
|
||||
r"\x1b\[31mERROR: Traceback \(most recent call last\):(?s:.)*in raise_error_step(?s:.)*ValueError: Error in step with [12]\n\x1b\[0m",
|
||||
caught_warnings[3].message.args[0],
|
||||
)
|
||||
assert (
|
||||
caught_warnings[4].message.args[0]
|
||||
== "\x1b[31mERROR: Raising the last exception back to the main process.\x1b[0m"
|
||||
)
|
||||
|
Reference in New Issue
Block a user