Files
Gymnasium/tests/envs/test_mujoco_custom_env.py

105 lines
3.1 KiB
Python
Raw Normal View History

__credits__ = ["Kallinteris-Andreas"]
import warnings
import numpy as np
import pytest
import gymnasium as gym
from gymnasium import utils
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.error import Error
from gymnasium.spaces import Box
class PointEnv(MujocoEnv, utils.EzPickle):
    """Minimal MuJoCo environment for testing third-party MuJoCo envs.

    The environment is built directly on the `gymnasium.envs.mujoco.MujocoEnv`
    base class. Observations are `qpos` followed by `qvel`, and the reward of a
    step is the change in `qpos[0]` (treated as the x position) over that step.
    Episodes never terminate or truncate on their own.
    """

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
    }

    def __init__(self, xml_file="point.xml", frame_skip=1, **kwargs):
        """Build the environment from a MuJoCo XML model.

        Args:
            xml_file: model file handed to `MujocoEnv.__init__`.
            frame_skip: forwarded to `MujocoEnv.__init__` and later passed to
                `do_simulation` on every `step`.
            **kwargs: forwarded unchanged to both `EzPickle.__init__` and
                `MujocoEnv.__init__`.
        """
        utils.EzPickle.__init__(self, xml_file, frame_skip, **kwargs)

        MujocoEnv.__init__(
            self,
            xml_file,
            frame_skip=frame_skip,
            # Placeholder: the real space is assigned below, because its size
            # is read from `self.data` once the base initialiser has run.
            observation_space=None,
            default_camera_config={},
            **kwargs,
        )

        # Instance-level metadata adds `render_fps`, derived from `self.dt`.
        render_fps = int(np.round(1.0 / self.dt))
        self.metadata = {
            "render_modes": ["human", "rgb_array", "depth_array"],
            "render_fps": render_fps,
        }

        # One observation entry per position and per velocity coordinate.
        n_qpos = self.data.qpos.size
        n_qvel = self.data.qvel.size
        self.observation_space = Box(
            low=-np.inf, high=np.inf, shape=(n_qpos + n_qvel,), dtype=np.float64
        )

    def step(self, action):
        """Advance the simulation by one environment step.

        Args:
            action: control input passed to `do_simulation`.

        Returns:
            A tuple `(observation, reward, terminated, truncated, info)` where
            `reward` is the change in `qpos[0]`, both flags are always `False`
            and `info` is an empty dict.
        """
        x_start = self.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        forward_progress = self.data.qpos[0] - x_start

        obs = self._get_obs()

        if self.render_mode == "human":
            self.render()

        return obs, forward_progress, False, False, {}

    def _get_obs(self):
        """Return a 1-D array holding a copy of `qpos` followed by `qvel`."""
        parts = (self.data.qpos.flat.copy(), self.data.qvel.flat.copy())
        return np.concatenate(parts)

    def reset_model(self):
        """Restore the initial `qpos`/`qvel` and return the resulting observation."""
        self.set_state(self.init_qpos, self.init_qvel)
        return self._get_obs()
# Warning texts that `test_frame_skip` tolerates when they are recorded while
# checking the environment. Each message is wrapped in the `\x1b[33m` ... `\x1b[0m`
# ANSI colour sequences with a `WARN: ` prefix, because the comparison is made
# against the full recorded warning text.
# NOTE(review): "maximum value is -infinity" reads like a typo, but the string
# must equal the emitted warning exactly - confirm against gymnasium's
# env checker before changing it.
CHECK_ENV_IGNORE_WARNINGS = list(
    map(
        "\x1b[33mWARN: {}\x1b[0m".format,
        (
            "A Box observation space minimum value is -infinity. This is probably too low.",
            "A Box observation space maximum value is -infinity. This is probably too high.",
            "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
        ),
    )
)
@pytest.mark.parametrize("frame_skip", [1, 2, 3, 4, 5])
def test_frame_skip(frame_skip):
    """Verify that custom envs work with different `frame_skip` values.

    Builds a `PointEnv` with the given `frame_skip`, runs `check_env` on the
    unwrapped environment while recording warnings, and fails if any recorded
    warning is not listed in `CHECK_ENV_IGNORE_WARNINGS`.

    Args:
        frame_skip: value forwarded to `PointEnv(frame_skip=...)`.

    Raises:
        Error: if a recorded warning's message is not in the ignore list.
    """
    env = PointEnv(frame_skip=frame_skip)

    # Test if env adheres to Gym API
    with warnings.catch_warnings(record=True) as w:
        try:
            gym.utils.env_checker.check_env(env.unwrapped, skip_render_check=True)
        finally:
            # Close even when `check_env` raises so the environment is not
            # leaked. This stays inside the `with` block so that warnings
            # emitted by `close()` are still recorded, as before.
            env.close()

    for warning in w:
        if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
            raise Error(f"Unexpected warning: {warning.message}")