2023-06-05 13:09:04 +03:00
__credits__ = [ " Kallinteris-Andreas " ]
2023-06-06 16:41:26 +03:00
import os
2023-06-05 13:09:04 +03:00
import warnings
import numpy as np
import pytest
from gymnasium import utils
from gymnasium . envs . mujoco import MujocoEnv
from gymnasium . error import Error
from gymnasium . spaces import Box
2023-06-06 16:41:26 +03:00
from gymnasium . utils . env_checker import check_env
2023-06-05 13:09:04 +03:00
class PointEnv(MujocoEnv, utils.EzPickle):
    """
    A simple MuJoCo env to test third-party MuJoCo envs, using the `Gymnasium.MujocoEnv` environment API.

    The observation is ``qpos`` concatenated with ``qvel``. The reward is the change of
    ``qpos[0]`` over one ``step`` call. The env never reports termination or truncation itself.
    """

    metadata = {
        "render_modes": [
            "human",
            "rgb_array",
            "depth_array",
        ],
    }

    def __init__(self, xml_file="point.xml", frame_skip=1, **kwargs):
        """
        Args:
            xml_file: MuJoCo model file to load, forwarded to ``MujocoEnv``.
            frame_skip: number of simulation steps performed per ``step`` call.
            **kwargs: extra keyword arguments forwarded to ``MujocoEnv``.
        """
        utils.EzPickle.__init__(self, xml_file, frame_skip, **kwargs)
        MujocoEnv.__init__(
            self,
            xml_file,
            frame_skip=frame_skip,
            observation_space=None,  # needs to be defined after
            default_camera_config={},
            **kwargs,
        )

        # `render_fps` is derived from `self.dt`, so it is only added after
        # `MujocoEnv.__init__` has run. The render modes are reused from the class
        # attribute instead of being duplicated; a new dict is built so the
        # class-level `metadata` itself is not mutated.
        self.metadata = {
            **self.metadata,
            "render_fps": int(np.round(1.0 / self.dt)),
        }

        # The observation size depends on the loaded model, hence defined here.
        obs_size = self.data.qpos.size + self.data.qvel.size
        self.observation_space = Box(
            low=-np.inf, high=np.inf, shape=(obs_size,), dtype=np.float64
        )

    def step(self, action):
        """Apply `action` for `frame_skip` simulation steps.

        Returns:
            (observation, reward, terminated, truncated, info), where reward is the
            change of ``qpos[0]`` and both termination flags are always False.
        """
        x_position_before = self.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        x_position_after = self.data.qpos[0]

        observation = self._get_obs()
        reward = x_position_after - x_position_before
        info = {}

        if self.render_mode == "human":
            self.render()

        return observation, reward, False, False, info

    def _get_obs(self):
        """Return copies of ``qpos`` and ``qvel`` concatenated into one flat array."""
        position = self.data.qpos.flat.copy()
        velocity = self.data.qvel.flat.copy()
        return np.concatenate((position, velocity))

    def reset_model(self):
        """Restore the initial state (``init_qpos``/``init_qvel``) and return its observation."""
        qpos = self.init_qpos
        qvel = self.init_qvel
        self.set_state(qpos, qvel)

        observation = self._get_obs()
        return observation
# Warnings from `check_env` that the tests below tolerate. Each entry is the full
# recorded text: "WARN: " + message, wrapped in the ANSI escape codes
# "\x1b[33m" (yellow) ... "\x1b[0m" (reset).
# NOTE(review): the second message reads "maximum value is -infinity"; it has to match
# the installed gymnasium checker's wording verbatim — confirm before changing it.
_EXPECTED_CHECK_ENV_MESSAGES = (
    "A Box observation space minimum value is -infinity. This is probably too low.",
    "A Box observation space maximum value is -infinity. This is probably too high.",
    "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
)
_YELLOW_WARN_PREFIX = "\x1b[33mWARN: "
_ANSI_RESET = "\x1b[0m"

CHECK_ENV_IGNORE_WARNINGS = [
    _YELLOW_WARN_PREFIX + text + _ANSI_RESET for text in _EXPECTED_CHECK_ENV_MESSAGES
]
@pytest.mark.parametrize("frame_skip", [1, 2, 3, 4, 5])
def test_frame_skip(frame_skip):
    """Verify that custom envs work with different `frame_skip` values.

    Runs `check_env` on a `PointEnv` and fails on any recorded warning whose text
    is not listed in `CHECK_ENV_IGNORE_WARNINGS`.

    Raises:
        Error: if an unexpected warning was recorded.
    """
    env = PointEnv(frame_skip=frame_skip)

    # Test if env adheres to Gym API; warnings are recorded instead of shown.
    with warnings.catch_warnings(record=True) as caught:
        try:
            check_env(env.unwrapped, skip_render_check=True)
        finally:
            # Close the env even when `check_env` raises, so a failing
            # parametrization does not leave the env open.
            env.close()

    for warning in caught:
        # `str()` of a Warning is its message text, and unlike `args[0]` it does
        # not raise IndexError for a warning created without arguments.
        if str(warning.message) not in CHECK_ENV_IGNORE_WARNINGS:
            raise Error(f"Unexpected warning: {warning.message}")
2023-06-06 16:41:26 +03:00
def test_xml_file():
    """Verify that the loading of a custom XML file works.

    The same asset is loaded once through a path relative to the current working
    directory and once through the equivalent absolute path.
    """
    # Relative to the current working directory, so this assumes the tests are run
    # from the directory that contains `tests/`.
    relative_path = "./tests/envs/mujoco/assets/walker2d_v5_uneven_feet.xml"
    # `abspath` joins with the current working directory and normalizes the result,
    # instead of concatenating strings with a hard-coded "/" separator.
    full_path = os.path.abspath(relative_path)
    assert os.path.isabs(full_path)

    for xml_file in (relative_path, full_path):
        env = PointEnv(xml_file=xml_file)
        try:
            # The custom walker2d asset is expected to expose 9 `qpos` entries.
            assert env.unwrapped.data.qpos.size == 9
        finally:
            env.close()

    # note: cannot test a user home path (with '~') because GitHub CI does not have a home folder