mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-28 09:17:18 +00:00
65 lines
2.3 KiB
Python
65 lines
2.3 KiB
Python
"""A set of MujocoEnv related utilities, mainly for testing purposes.
|
|
|
|
Author: @Kallinteris-Andreas
|
|
"""
|
|
|
|
import mujoco
|
|
import numpy as np
|
|
|
|
import gymnasium
|
|
|
|
|
|
def get_state(
    env: gymnasium.envs.mujoco.MujocoEnv,
    state_type: mujoco.mjtState = mujoco.mjtState.mjSTATE_PHYSICS,
):
    """Gets the state of `env`.

    Arguments:
        env: Environment whose state to copy, `env.unwrapped.model` & `env.unwrapped.data` must be accessible.
        state_type: see the [documentation of mjtState](https://mujoco.readthedocs.io/en/stable/APIreference/APItypes.html#mjtstate) most users can use the default for training purposes or `mujoco.mjtState.mjSTATE_INTEGRATION` for validation purposes.

    Returns:
        A newly allocated 1-D array of length `mujoco.mj_stateSize(model, state_type)`,
        filled by `mujoco.mj_getState` with the components selected by `state_type`.

    Raises:
        AssertionError: if the installed `mujoco` is older than 2.3.6.
    """
    import re

    # Compare the version numerically. A plain string comparison such as
    # `mujoco.__version__ >= "2.3.6"` is lexicographic and would reject
    # e.g. "2.3.10" or "10.0.0".
    version_match = re.match(r"(\d+)\.(\d+)\.(\d+)", mujoco.__version__)
    mujoco_version = (
        tuple(int(part) for part in version_match.groups())
        if version_match
        else (0, 0, 0)
    )
    assert mujoco_version >= (2, 3, 6), "Feature requires `mujoco>=2.3.6`"

    model, data = env.unwrapped.model, env.unwrapped.data
    state = np.empty(mujoco.mj_stateSize(model, state_type))
    mujoco.mj_getState(model, data, state, state_type)
    return state
|
|
|
|
|
|
def set_state(
    env: gymnasium.envs.mujoco.MujocoEnv,
    state: np.ndarray,
    state_type: mujoco.mjtState = mujoco.mjtState.mjSTATE_PHYSICS,
):
    """Set the state of `env`.

    Arguments:
        env: Environment whose state to set, `env.unwrapped.model` & `env.unwrapped.data` must be accessible.
        state: State to set (generated from get_state with the same `state_type`).
        state_type: see the [documentation of mjtState](https://mujoco.readthedocs.io/en/stable/APIreference/APItypes.html#mjtstate) most users can use the default for training purposes or `mujoco.mjtState.mjSTATE_INTEGRATION` for validation purposes.

    Returns:
        The `state` argument, unchanged.

    Raises:
        AssertionError: if the installed `mujoco` is older than 2.3.6.
    """
    import re

    # Compare the version numerically. A plain string comparison such as
    # `mujoco.__version__ >= "2.3.6"` is lexicographic and would reject
    # e.g. "2.3.10" or "10.0.0".
    version_match = re.match(r"(\d+)\.(\d+)\.(\d+)", mujoco.__version__)
    mujoco_version = (
        tuple(int(part) for part in version_match.groups())
        if version_match
        else (0, 0, 0)
    )
    assert mujoco_version >= (2, 3, 6), "Feature requires `mujoco>=2.3.6`"

    mujoco.mj_setState(
        env.unwrapped.model,
        env.unwrapped.data,
        state,
        # Use the caller's `state_type`. Hard-coding mjSTATE_PHYSICS here would
        # misinterpret states captured with any other spec (e.g. INTEGRATION).
        spec=state_type,
    )
    return state
|
|
|
|
|
|
def check_mujoco_reset_state(env: gymnasium.envs.mujoco.MujocoEnv, seed=1234):
    """Asserts that `env.reset` properly resets the state (not affected by previous steps), assuming `check_reset_seed` has passed.

    Resets with `seed`, snapshots the full integration state, takes one step,
    resets again with the same `seed`, and requires both snapshots to be equal.
    """
    env.action_space.seed(seed)
    probe_action = env.action_space.sample()

    def _reset_and_capture():
        # Reset with the fixed seed, then snapshot the full integration state.
        env.reset(seed=seed)
        return get_state(env, mujoco.mjtState.mjSTATE_INTEGRATION)

    state_before_step = _reset_and_capture()
    # Perturb the simulation so that any state leaking through reset shows up.
    env.step(probe_action)
    state_after_step = _reset_and_capture()

    identical = np.all(state_before_step == state_after_step)
    assert identical, "reset is not deterministic"
|