diff --git a/gym/envs/mujoco/mujoco_env.py b/gym/envs/mujoco/mujoco_env.py index 9df228334..bcfbd2c51 100644 --- a/gym/envs/mujoco/mujoco_env.py +++ b/gym/envs/mujoco/mujoco_env.py @@ -128,6 +128,9 @@ class MujocoEnv(gym.Env): return self.model.opt.timestep * self.frame_skip def do_simulation(self, ctrl, n_frames): + if np.array(ctrl).shape != self.action_space.shape: + raise ValueError("Action dimension mismatch") + self.sim.data.ctrl[:] = ctrl for _ in range(n_frames): self.sim.step() diff --git a/gym/envs/robotics/robot_env.py b/gym/envs/robotics/robot_env.py index f7b496c31..eb0adb2de 100644 --- a/gym/envs/robotics/robot_env.py +++ b/gym/envs/robotics/robot_env.py @@ -70,6 +70,9 @@ class RobotEnv(gym.GoalEnv): return [seed] def step(self, action): + if np.array(action).shape != self.action_space.shape: + raise ValueError("Action dimension mismatch") + action = np.clip(action, self.action_space.low, self.action_space.high) self._set_action(action) self.sim.step() diff --git a/tests/envs/test_action_dim_check.py b/tests/envs/test_action_dim_check.py new file mode 100644 index 000000000..31f3ba59d --- /dev/null +++ b/tests/envs/test_action_dim_check.py @@ -0,0 +1,25 @@ +import pickle + +import pytest + +from gym import envs +from tests.envs.spec_list import skip_mujoco, SKIP_MUJOCO_WARNING_MESSAGE + + +ENVIRONMENT_IDS = ( + "FetchReach-v1", + "HalfCheetah-v2", +) + + +@pytest.mark.skipif(skip_mujoco, reason=SKIP_MUJOCO_WARNING_MESSAGE) +@pytest.mark.parametrize("environment_id", ENVIRONMENT_IDS) +def test_serialize_deserialize(environment_id): + env = envs.make(environment_id) + env.reset() + + with pytest.raises(ValueError, match="Action dimension mismatch"): + env.step([0.1]) + + with pytest.raises(ValueError, match="Action dimension mismatch"): + env.step(0.1)