[Bugfix] Prevent step() in AsyncVectorEnv from hanging forever (#1419)

This commit is contained in:
Matin Moezzi
2025-07-22 17:45:23 -04:00
committed by GitHub
parent 58c6b7b49d
commit ad23dfbbe2
3 changed files with 48 additions and 2 deletions

View File

@@ -410,7 +410,7 @@ class AsyncVectorEnv(VectorEnv):
)
iter_actions = iterate(self.action_space, actions)
for pipe, action in zip(self.parent_pipes, iter_actions):
for pipe, action in zip(self.parent_pipes, iter_actions, strict=True):
pipe.send(("step", action))
self._state = AsyncState.WAITING_STEP

View File

@@ -247,7 +247,7 @@ class SyncVectorEnv(VectorEnv):
actions = iterate(self.action_space, actions)
infos = {}
for i, action in enumerate(actions):
for i, (action, _) in enumerate(zip(actions, self.envs, strict=True)):
if self.autoreset_mode == AutoresetMode.NEXT_STEP:
if self._autoreset_envs[i]:
self._env_obs[i], env_info = self.envs[i].reset()

View File

@@ -317,3 +317,49 @@ def test_partial_reset_failure(vectoriser):
),
):
envs.reset(options={"reset_mask": np.array([1.0, 1.0, 0.0])})
@pytest.mark.parametrize(
    "vectoriser",
    [
        SyncVectorEnv,
        AsyncVectorEnv,
        partial(AsyncVectorEnv, shared_memory=False),
    ],
    ids=["Sync", "Async(shared_memory=True)", "Async(shared_memory=False)"],
)
def test_action_count_compatibility(vectoriser):
    """Test that the number of actions is compatible with the number of environments.

    A batch with exactly one action per sub-environment must step normally
    and return one entry per sub-environment. A batch with too few or too
    many actions must raise ``ValueError``, and a bare scalar must raise
    ``TypeError``, rather than being silently truncated or blocking.

    Args:
        vectoriser: Callable that builds the vector environment under test
            from a list of environment factories.
    """
    num_envs = 4
    envs = vectoriser(
        [lambda: gym.make("CartPole-v1") for _ in range(num_envs)],
        autoreset_mode=AutoresetMode.DISABLED,
    )

    # Close the vector env even when an assertion below fails; otherwise a
    # failing run never reaches ``close()`` and the environments stay open
    # for the remainder of the test session.
    try:
        envs.reset()

        # A batched sample holds exactly one action per sub-environment.
        correct_actions = envs.action_space.sample()
        assert len(correct_actions) == num_envs

        # Matching action count: every returned batch has num_envs entries.
        obs, rewards, terminations, truncations, infos = envs.step(correct_actions)
        assert len(obs) == num_envs
        assert len(rewards) == num_envs
        assert len(terminations) == num_envs
        assert len(truncations) == num_envs

        # NOTE(review): for the async vectorisers a rejected step may already
        # have dispatched "step" to some workers before the error surfaces;
        # only the raised exception type is checked here -- confirm that the
        # later steps in this test are unaffected by any leftover results.

        # Too few actions (one short of the number of environments).
        with pytest.raises(ValueError):
            envs.step(correct_actions[: num_envs - 1])

        # Too many actions (one more than the number of environments).
        with pytest.raises(ValueError):
            envs.step(np.concatenate([correct_actions, correct_actions[:1]]))

        # A scalar is not a batch of actions at all.
        with pytest.raises(TypeError):
            envs.step(0)
    finally:
        envs.close()