Fix action dimension check bugs (#2469)

* fix action dimension check bugs

* black codes, add test function

* clear codes for simplicity

* update check mujoco install
This commit is contained in:
Minghuan Liu
2021-11-18 07:11:40 +08:00
committed by GitHub
parent 01cc8a3a16
commit 01b4519d9d
3 changed files with 31 additions and 0 deletions

View File

@@ -128,6 +128,9 @@ class MujocoEnv(gym.Env):
return self.model.opt.timestep * self.frame_skip return self.model.opt.timestep * self.frame_skip
def do_simulation(self, ctrl, n_frames): def do_simulation(self, ctrl, n_frames):
if np.array(ctrl).shape != self.action_space.shape:
raise ValueError("Action dimension mismatch")
self.sim.data.ctrl[:] = ctrl self.sim.data.ctrl[:] = ctrl
for _ in range(n_frames): for _ in range(n_frames):
self.sim.step() self.sim.step()

View File

@@ -70,6 +70,9 @@ class RobotEnv(gym.GoalEnv):
return [seed] return [seed]
def step(self, action): def step(self, action):
if np.array(action).shape != self.action_space.shape:
raise ValueError("Action dimension mismatch")
action = np.clip(action, self.action_space.low, self.action_space.high) action = np.clip(action, self.action_space.low, self.action_space.high)
self._set_action(action) self._set_action(action)
self.sim.step() self.sim.step()

View File

@@ -0,0 +1,25 @@
import pickle
import pytest
from gym import envs
from tests.envs.spec_list import skip_mujoco, SKIP_MUJOCO_WARNING_MESSAGE
ENVIRONMENT_IDS = (
"FetchReach-v1",
"HalfCheetah-v2",
)
@pytest.mark.skipif(skip_mujoco, reason=SKIP_MUJOCO_WARNING_MESSAGE)
@pytest.mark.parametrize("environment_id", ENVIRONMENT_IDS)
def test_serialize_deserialize(environment_id):
env = envs.make(environment_id)
env.reset()
with pytest.raises(ValueError, match="Action dimension mismatch"):
env.step([0.1])
with pytest.raises(ValueError, match="Action dimension mismatch"):
env.step(0.1)