Add check_mujoco_reset_state (#928)

This commit is contained in:
Kallinteris Andreas
2024-02-19 13:24:59 +02:00
committed by GitHub
parent 0037cbc4ab
commit ceb97f29a6
2 changed files with 73 additions and 0 deletions

View File

@@ -0,0 +1,64 @@
"""A set of MujocoEnv related utilities, mainly for testing purposes.
Author: @Kallinteris-Andreas
"""
import mujoco
import numpy as np
import gymnasium
def get_state(
env: gymnasium.envs.mujoco.MujocoEnv,
state_type: mujoco.mjtState = mujoco.mjtState.mjSTATE_PHYSICS,
):
"""Gets the state of `env`.
Arguments:
env: Environment whose state to copy, `env.model` & `env.data` must be accessible.
state_type: see the [documentation of mjtState](https://mujoco.readthedocs.io/en/stable/APIreference/APItypes.html#mjtstate) most users can use the default for training purposes or `mujoco.mjtState.mjSTATE_INTEGRATION` for validation purposes.
"""
assert mujoco.__version__ >= "2.3.6", "Feature requires `mujuco>=2.3.6`"
state = np.empty(mujoco.mj_stateSize(env.unwrapped.model, state_type))
mujoco.mj_getState(env.unwrapped.model, env.unwrapped.data, state, state_type)
return state
def set_state(
env: gymnasium.envs.mujoco.MujocoEnv,
state: np.ndarray,
state_type: mujoco.mjtState = mujoco.mjtState.mjSTATE_PHYSICS,
):
"""Set the state of `env`.
Arguments:
env: Environment whose state to set, `env.model` & `env.data` must be accessible.
state: State to set (generated from get_state).
state_type: see the [documentation of mjtState](https://mujoco.readthedocs.io/en/stable/APIreference/APItypes.html#mjtstate) most users can use the default for training purposes or `mujoco.mjtState.mjSTATE_INTEGRATION` for validation purposes.
"""
assert mujoco.__version__ >= "2.3.6", "Feature requires `mujuco>=2.3.6`"
mujoco.mj_setState(
env.unwrapped.model,
env.unwrapped.data,
state,
spec=mujoco.mjtState.mjSTATE_PHYSICS,
)
return state
def check_mujoco_reset_state(env: gymnasium.envs.mujoco.MujocoEnv, seed=1234):
"""Asserts that `env.reset` properly resets the state (not affected by previous steps), assuming `check_reset_seed` has passed."""
env.action_space.seed(seed)
action = env.action_space.sample()
env.reset(seed=seed)
first_reset_state = get_state(env, mujoco.mjtState.mjSTATE_INTEGRATION)
env.step(action)
env.reset(seed=seed)
second_reset_state = get_state(env, mujoco.mjtState.mjSTATE_INTEGRATION)
assert np.all(first_reset_state == second_reset_state), "reset is not deterministic"

View File

@@ -7,6 +7,7 @@ import pytest
import gymnasium as gym
from gymnasium.envs.mujoco.mujoco_env import BaseMujocoEnv, MujocoEnv
from gymnasium.envs.mujoco.utils import check_mujoco_reset_state
from gymnasium.error import Error
from gymnasium.utils.env_checker import check_env
from gymnasium.utils.env_match import check_environments_match
@@ -698,3 +699,11 @@ def test_reset_noise_scale(env_id):
assert np.all(env.data.qpos == env.init_qpos)
assert np.all(env.data.qvel == env.init_qvel)
@pytest.mark.parametrize("env_name", ALL_MUJOCO_ENVS)
@pytest.mark.parametrize("version", ["v5", "v4"])
def test_reset_state(env_name: str, version: str):
"""Asserts that `reset()` properly resets the internal state."""
env = gym.make(f"{env_name}-{version}")
check_mujoco_reset_state(env)