mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-28 01:07:11 +00:00
Add check_mujoco_reset_state
(#928)
This commit is contained in:
committed by
GitHub
parent
0037cbc4ab
commit
ceb97f29a6
64
gymnasium/envs/mujoco/utils.py
Normal file
64
gymnasium/envs/mujoco/utils.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""A set of MujocoEnv related utilities, mainly for testing purposes.
|
||||||
|
|
||||||
|
Author: @Kallinteris-Andreas
|
||||||
|
"""
|
||||||
|
|
||||||
|
import mujoco
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import gymnasium
|
||||||
|
|
||||||
|
|
||||||
|
def get_state(
|
||||||
|
env: gymnasium.envs.mujoco.MujocoEnv,
|
||||||
|
state_type: mujoco.mjtState = mujoco.mjtState.mjSTATE_PHYSICS,
|
||||||
|
):
|
||||||
|
"""Gets the state of `env`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
env: Environment whose state to copy, `env.model` & `env.data` must be accessible.
|
||||||
|
state_type: see the [documentation of mjtState](https://mujoco.readthedocs.io/en/stable/APIreference/APItypes.html#mjtstate) most users can use the default for training purposes or `mujoco.mjtState.mjSTATE_INTEGRATION` for validation purposes.
|
||||||
|
"""
|
||||||
|
assert mujoco.__version__ >= "2.3.6", "Feature requires `mujuco>=2.3.6`"
|
||||||
|
|
||||||
|
state = np.empty(mujoco.mj_stateSize(env.unwrapped.model, state_type))
|
||||||
|
mujoco.mj_getState(env.unwrapped.model, env.unwrapped.data, state, state_type)
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def set_state(
|
||||||
|
env: gymnasium.envs.mujoco.MujocoEnv,
|
||||||
|
state: np.ndarray,
|
||||||
|
state_type: mujoco.mjtState = mujoco.mjtState.mjSTATE_PHYSICS,
|
||||||
|
):
|
||||||
|
"""Set the state of `env`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
env: Environment whose state to set, `env.model` & `env.data` must be accessible.
|
||||||
|
state: State to set (generated from get_state).
|
||||||
|
state_type: see the [documentation of mjtState](https://mujoco.readthedocs.io/en/stable/APIreference/APItypes.html#mjtstate) most users can use the default for training purposes or `mujoco.mjtState.mjSTATE_INTEGRATION` for validation purposes.
|
||||||
|
"""
|
||||||
|
assert mujoco.__version__ >= "2.3.6", "Feature requires `mujuco>=2.3.6`"
|
||||||
|
|
||||||
|
mujoco.mj_setState(
|
||||||
|
env.unwrapped.model,
|
||||||
|
env.unwrapped.data,
|
||||||
|
state,
|
||||||
|
spec=mujoco.mjtState.mjSTATE_PHYSICS,
|
||||||
|
)
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def check_mujoco_reset_state(env: gymnasium.envs.mujoco.MujocoEnv, seed=1234):
|
||||||
|
"""Asserts that `env.reset` properly resets the state (not affected by previous steps), assuming `check_reset_seed` has passed."""
|
||||||
|
env.action_space.seed(seed)
|
||||||
|
action = env.action_space.sample()
|
||||||
|
|
||||||
|
env.reset(seed=seed)
|
||||||
|
first_reset_state = get_state(env, mujoco.mjtState.mjSTATE_INTEGRATION)
|
||||||
|
env.step(action)
|
||||||
|
|
||||||
|
env.reset(seed=seed)
|
||||||
|
second_reset_state = get_state(env, mujoco.mjtState.mjSTATE_INTEGRATION)
|
||||||
|
|
||||||
|
assert np.all(first_reset_state == second_reset_state), "reset is not deterministic"
|
@@ -7,6 +7,7 @@ import pytest
|
|||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.envs.mujoco.mujoco_env import BaseMujocoEnv, MujocoEnv
|
from gymnasium.envs.mujoco.mujoco_env import BaseMujocoEnv, MujocoEnv
|
||||||
|
from gymnasium.envs.mujoco.utils import check_mujoco_reset_state
|
||||||
from gymnasium.error import Error
|
from gymnasium.error import Error
|
||||||
from gymnasium.utils.env_checker import check_env
|
from gymnasium.utils.env_checker import check_env
|
||||||
from gymnasium.utils.env_match import check_environments_match
|
from gymnasium.utils.env_match import check_environments_match
|
||||||
@@ -698,3 +699,11 @@ def test_reset_noise_scale(env_id):
|
|||||||
|
|
||||||
assert np.all(env.data.qpos == env.init_qpos)
|
assert np.all(env.data.qpos == env.init_qpos)
|
||||||
assert np.all(env.data.qvel == env.init_qvel)
|
assert np.all(env.data.qvel == env.init_qvel)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("env_name", ALL_MUJOCO_ENVS)
|
||||||
|
@pytest.mark.parametrize("version", ["v5", "v4"])
|
||||||
|
def test_reset_state(env_name: str, version: str):
|
||||||
|
"""Asserts that `reset()` properly resets the internal state."""
|
||||||
|
env = gym.make(f"{env_name}-{version}")
|
||||||
|
check_mujoco_reset_state(env)
|
||||||
|
Reference in New Issue
Block a user