Mirror of https://github.com/Farama-Foundation/Gymnasium.git (synced 2025-07-31 05:44:31 +00:00)
[Bugfix] Fixed the step() function in AsyncVectorEnv so that it no longer hangs forever (#1419)
This commit is contained in:
@@ -410,7 +410,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
)
|
||||
|
||||
iter_actions = iterate(self.action_space, actions)
|
||||
for pipe, action in zip(self.parent_pipes, iter_actions):
|
||||
for pipe, action in zip(self.parent_pipes, iter_actions, strict=True):
|
||||
pipe.send(("step", action))
|
||||
self._state = AsyncState.WAITING_STEP
|
||||
|
||||
|
@@ -247,7 +247,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
actions = iterate(self.action_space, actions)
|
||||
|
||||
infos = {}
|
||||
for i, action in enumerate(actions):
|
||||
for i, (action, _) in enumerate(zip(actions, self.envs, strict=True)):
|
||||
if self.autoreset_mode == AutoresetMode.NEXT_STEP:
|
||||
if self._autoreset_envs[i]:
|
||||
self._env_obs[i], env_info = self.envs[i].reset()
|
||||
|
@@ -317,3 +317,49 @@ def test_partial_reset_failure(vectoriser):
|
||||
),
|
||||
):
|
||||
envs.reset(options={"reset_mask": np.array([1.0, 1.0, 0.0])})
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "vectoriser",
    [
        SyncVectorEnv,
        AsyncVectorEnv,
        partial(AsyncVectorEnv, shared_memory=False),
    ],
    ids=["Sync", "Async(shared_memory=True)", "Async(shared_memory=False)"],
)
def test_action_count_compatibility(vectoriser):
    """Test that the number of actions is compatible with the number of environments.

    A batch with exactly ``num_envs`` actions must step successfully and
    return one entry per environment. A batch with too few or too many
    actions must raise ``ValueError``, and a bare scalar action must raise
    ``TypeError``.
    """
    num_envs = 4
    envs = vectoriser(
        [lambda: gym.make("CartPole-v1") for _ in range(num_envs)],
        autoreset_mode=AutoresetMode.DISABLED,
    )

    # Close in ``finally`` so the vector env is always closed, even when one
    # of the assertions below fails.
    try:
        # Reset the environment
        envs.reset()

        # Test correct number of actions (should work)
        correct_actions = envs.action_space.sample()
        assert len(correct_actions) == num_envs

        obs, rewards, terminations, truncations, infos = envs.step(correct_actions)
        assert len(obs) == num_envs
        assert len(rewards) == num_envs
        assert len(terminations) == num_envs
        assert len(truncations) == num_envs

        # Test with too few actions (should raise error)
        with pytest.raises(ValueError):
            envs.step(correct_actions[: num_envs - 1])

        # Test with too many actions (should raise error)
        with pytest.raises(ValueError):
            envs.step(np.concatenate([correct_actions, correct_actions[:1]]))

        # Test with scalar action (should raise error for vector env)
        with pytest.raises(TypeError):
            envs.step(0)
    finally:
        envs.close()
|
||||
|
Reference in New Issue
Block a user