diff --git a/gymnasium/vector/async_vector_env.py b/gymnasium/vector/async_vector_env.py index 78584636d..b6195298d 100644 --- a/gymnasium/vector/async_vector_env.py +++ b/gymnasium/vector/async_vector_env.py @@ -410,7 +410,7 @@ class AsyncVectorEnv(VectorEnv): ) iter_actions = iterate(self.action_space, actions) - for pipe, action in zip(self.parent_pipes, iter_actions): + for pipe, action in zip(self.parent_pipes, iter_actions, strict=True): pipe.send(("step", action)) self._state = AsyncState.WAITING_STEP diff --git a/gymnasium/vector/sync_vector_env.py b/gymnasium/vector/sync_vector_env.py index cb1870d96..1f652b23f 100644 --- a/gymnasium/vector/sync_vector_env.py +++ b/gymnasium/vector/sync_vector_env.py @@ -247,7 +247,7 @@ class SyncVectorEnv(VectorEnv): actions = iterate(self.action_space, actions) infos = {} - for i, action in enumerate(actions): + for i, (action, _) in enumerate(zip(actions, self.envs, strict=True)): if self.autoreset_mode == AutoresetMode.NEXT_STEP: if self._autoreset_envs[i]: self._env_obs[i], env_info = self.envs[i].reset() diff --git a/tests/vector/test_vector_env.py b/tests/vector/test_vector_env.py index a88bfe5b0..cb29737b0 100644 --- a/tests/vector/test_vector_env.py +++ b/tests/vector/test_vector_env.py @@ -317,3 +317,49 @@ def test_partial_reset_failure(vectoriser): ), ): envs.reset(options={"reset_mask": np.array([1.0, 1.0, 0.0])}) + + +@pytest.mark.parametrize( + "vectoriser", + [ + SyncVectorEnv, + AsyncVectorEnv, + partial(AsyncVectorEnv, shared_memory=False), + ], + ids=["Sync", "Async(shared_memory=True)", "Async(shared_memory=False)"], +) +def test_action_count_compatibility(vectoriser): + """Test that the number of actions is compatible with the number of environments.""" + num_envs = 4 + envs = vectoriser( + [lambda: gym.make("CartPole-v1") for _ in range(num_envs)], + autoreset_mode=AutoresetMode.DISABLED, + ) + + # Reset the environment + envs.reset() + + # Test correct number of actions (should work) + correct_actions = envs.action_space.sample() + assert len(correct_actions) == num_envs + + # Test with actions that match the number of environments + obs, rewards, terminations, truncations, infos = envs.step(correct_actions) + assert len(obs) == num_envs + assert len(rewards) == num_envs + assert len(terminations) == num_envs + assert len(truncations) == num_envs + + # Test with too few actions (should raise error) + with pytest.raises(ValueError): + envs.step(correct_actions[: num_envs - 1]) + + # Test with too many actions (should raise error) + with pytest.raises(ValueError): + envs.step(np.concatenate([correct_actions, correct_actions[:1]])) + + # Test with scalar action (should raise error for vector env) + with pytest.raises(TypeError): + envs.step(0) + + envs.close()