__credits__ = ["Kallinteris-Andreas"]

import os
import warnings

import numpy as np
import pytest

from gymnasium import utils
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.error import Error
from gymnasium.spaces import Box
from gymnasium.utils.env_checker import check_env


class PointEnv(MujocoEnv, utils.EzPickle):
    """
    A simple MuJoCo env to test third-party MuJoCo envs, using the `gymnasium.MujocoEnv` environment API.

    Observations are the concatenation of `qpos` and `qvel`; the reward is the
    change of `qpos[0]` over one `step()`; episodes never terminate or truncate.
    """

    metadata = {
        "render_modes": [
            "human",
            "rgb_array",
            "depth_array",
        ],
    }

    def __init__(self, xml_file="point.xml", frame_skip=1, **kwargs):
        """
        Args:
            xml_file: MJCF model file to load (relative and absolute paths are exercised by the tests below).
            frame_skip: number of simulator steps performed per `step()` call.
            **kwargs: forwarded to both `EzPickle.__init__` and `MujocoEnv.__init__`.
        """
        utils.EzPickle.__init__(self, xml_file, frame_skip, **kwargs)
        MujocoEnv.__init__(
            self,
            xml_file,
            frame_skip=frame_skip,
            observation_space=None,  # needs to be defined after
            default_camera_config={},
            **kwargs,
        )

        # `render_fps` is derived from `self.dt`, which is only read here after the
        # base initializer has loaded the model, so the metadata is completed now.
        self.metadata = {
            "render_modes": [
                "human",
                "rgb_array",
                "depth_array",
            ],
            "render_fps": int(np.round(1.0 / self.dt)),
        }

        # The observation is qpos followed by qvel (see `_get_obs`), so its size
        # depends on the loaded model.
        obs_size = self.data.qpos.size + self.data.qvel.size
        self.observation_space = Box(
            low=-np.inf, high=np.inf, shape=(obs_size,), dtype=np.float64
        )

    def step(self, action):
        """Advance the simulation by `frame_skip` steps.

        Returns:
            (observation, reward, terminated=False, truncated=False, info={}),
            where reward is the displacement of `qpos[0]` during this step.
        """
        x_position_before = self.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        x_position_after = self.data.qpos[0]

        observation = self._get_obs()
        reward = x_position_after - x_position_before
        info = {}

        if self.render_mode == "human":
            self.render()
        return observation, reward, False, False, info

    def _get_obs(self):
        """Return a copy of `qpos` concatenated with a copy of `qvel`."""
        position = self.data.qpos.flat.copy()
        velocity = self.data.qvel.flat.copy()

        return np.concatenate((position, velocity))

    def reset_model(self):
        """Reset to the model's initial state (`init_qpos`, `init_qvel`) without noise."""
        qpos = self.init_qpos
        qvel = self.init_qvel
        self.set_state(qpos, qvel)

        observation = self._get_obs()

        return observation

    def _get_reset_info(self):
        """Return a marker dict so tests can verify that `reset()` forwards info."""
        return {"works": True}


# Warnings emitted by `check_env` that are expected for `PointEnv` (unbounded
# observation space, non-normalized action space).  Each message is wrapped in
# the ANSI yellow "WARN:" formatting used by gymnasium's logger.
CHECK_ENV_IGNORE_WARNINGS = [
    f"\x1b[33mWARN: {message}\x1b[0m"
    for message in [
        "A Box observation space minimum value is -infinity. This is probably too low.",
        "A Box observation space maximum value is infinity. This is probably too high.",
        "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
    ]
]


@pytest.mark.parametrize("frame_skip", [1, 2, 3, 4, 5])
def test_frame_skip(frame_skip):
    """verify that custom envs work with different `frame_skip` values"""
    env = PointEnv(frame_skip=frame_skip)

    # Test if env adheres to Gym API
    with warnings.catch_warnings(record=True) as w:
        check_env(env.unwrapped, skip_render_check=True)
        env.close()
    for warning in w:
        # `str()` yields the message text and, unlike `.args[0]`, cannot fail
        # on a warning constructed without arguments.
        if str(warning.message) not in CHECK_ENV_IGNORE_WARNINGS:
            raise Error(f"Unexpected warning: {warning.message}")


def test_xml_file():
    """Verify that the loading of a custom XML file works"""
    # NOTE: these paths assume pytest is run from the repository root.
    relative_path = "./tests/envs/mujoco/assets/walker2d_v5_uneven_feet.xml"
    env = PointEnv(xml_file=relative_path).unwrapped
    assert isinstance(env, MujocoEnv)
    assert env.data.qpos.size == 9
    env.close()

    full_path = os.path.join(
        os.getcwd(), "tests/envs/mujoco/assets/walker2d_v5_uneven_feet.xml"
    )
    env = PointEnv(xml_file=full_path).unwrapped
    assert isinstance(env, MujocoEnv)
    assert env.data.qpos.size == 9
    env.close()

    # note can not test user home path (with '~') because github CI does not have a home folder


def test_reset_info():
    """Verify that the environment returns info at `reset()`"""
    env = PointEnv()

    _, info = env.reset()
    assert info["works"] is True
    env.close()