2016-04-30 22:47:51 -07:00
import os
2016-04-27 08:00:58 -07:00
2016-04-30 22:47:51 -07:00
from gym import error , spaces
2016-04-27 08:00:58 -07:00
import numpy as np
2016-04-30 22:47:51 -07:00
from os import path
2016-04-27 08:00:58 -07:00
import gym
2016-04-28 22:32:17 -07:00
import six
2016-04-27 08:00:58 -07:00
# mujoco_py (and its low-level mjlib binding) is required by this module.
# If either import fails, re-raise as gym's DependencyNotInstalled error so
# the user sees the original failure text followed by an installation hint.
try:
    import mujoco_py
    from mujoco_py.mjlib import mjlib
except ImportError as import_failure:
    raise error.DependencyNotInstalled(
        " {} . (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.) ".format(
            import_failure
        )
    )
2016-04-27 08:00:58 -07:00
class MujocoEnv(gym.Env):
    """Superclass of MuJoCo environments.

    Loads a MuJoCo model file, builds the action and observation spaces
    from it, and offers helpers for stepping the simulation, rendering and
    querying bodies by name.

    Subclasses must implement ``reset_model`` and may override
    ``viewer_setup``.  ``__init__`` also calls ``self._step``, which is not
    defined in this class and is therefore expected to be supplied by the
    subclass (or a parent class).
    """

    def __init__(self, model_path, frame_skip):
        """Load the model and derive spaces from one zero-action step.

        Args:
            model_path: Path to the model file.  An absolute path is used
                as-is; anything else is looked up in the ``assets``
                directory next to this module.
            frame_skip: Number of simulator steps per environment step;
                used by the ``dt`` property.

        Raises:
            IOError: If the resolved model file does not exist.
        """
        # Keep the original "/" prefix test so existing callers resolve the
        # same way, and additionally honour platform-specific absolute
        # paths (e.g. drive-letter paths).
        if model_path.startswith("/") or os.path.isabs(model_path):
            fullpath = model_path
        else:
            fullpath = os.path.join(
                os.path.dirname(__file__), "assets", model_path)
        if not path.exists(fullpath):
            raise IOError("File %s does not exist" % fullpath)

        self.frame_skip = frame_skip
        self.model = mujoco_py.MjModel(fullpath)
        self.data = self.model.data
        self.viewer = None

        # ``self.dt`` needs both ``model`` and ``frame_skip``, so this must
        # come after the assignments above.
        self.metadata = {
            'render.modes': ['human', 'rgb_array'],
            'video.frames_per_second': int(np.round(1.0 / self.dt)),
        }

        # Snapshot of the freshly loaded state, flattened and copied so
        # later simulation steps do not alter it.
        self.init_qpos = self.model.data.qpos.ravel().copy()
        self.init_qvel = self.model.data.qvel.ravel().copy()

        # Take one step with an all-zero control vector purely to learn the
        # size of the observation.
        observation, _reward, done, _info = self._step(
            np.zeros(self.model.nu))
        assert not done
        self.obs_dim = observation.size

        # Action bounds come from the model's actuator control ranges
        # (column 0 = lower bound, column 1 = upper bound).
        bounds = self.model.actuator_ctrlrange.copy()
        low = bounds[:, 0]
        high = bounds[:, 1]
        self.action_space = spaces.Box(low, high)

        # Observations are unbounded in every dimension.
        high = np.inf * np.ones(self.obs_dim)
        low = -high
        self.observation_space = spaces.Box(low, high)

    # methods to override:
    # ----------------------------

    def reset_model(self):
        """Reset the robot degrees of freedom (qpos and qvel).

        Implement this in each subclass.  ``_reset`` returns whatever this
        method returns.
        """
        raise NotImplementedError

    def viewer_setup(self):
        """Hook for adjusting the viewer (camera position and so forth).

        Called when the viewer is first created and after every reset while
        a viewer exists.  The default implementation does nothing.
        """
        pass

    # -----------------------------

    def _reset(self):
        """Reset the simulator data, then the model; return the observation."""
        mjlib.mj_resetData(self.model.ptr, self.data.ptr)
        ob = self.reset_model()
        if self.viewer is not None:
            self.viewer.autoscale()
            self.viewer_setup()
        return ob

    def set_state(self, qpos, qvel):
        """Overwrite positions and velocities and recompute derived data.

        Args:
            qpos: Array of shape ``(self.model.nq,)``.
            qvel: Array of shape ``(self.model.nv,)``.

        Raises:
            AssertionError: If either array has the wrong shape.
        """
        assert (qpos.shape == (self.model.nq,)
                and qvel.shape == (self.model.nv,))
        self.model.data.qpos = qpos
        self.model.data.qvel = qvel
        self.model._compute_subtree()  # pylint: disable=W0212
        self.model.forward()

    @property
    def dt(self):
        """Duration of one environment step: model timestep * frame_skip."""
        return self.model.opt.timestep * self.frame_skip

    def do_simulation(self, ctrl, n_frames):
        """Apply ``ctrl`` as the control input and step ``n_frames`` times."""
        self.model.data.ctrl = ctrl
        for _ in range(n_frames):
            self.model.step()

    def _render(self, mode='human', close=False):
        """Render the current state, or shut the viewer down.

        Args:
            mode: ``'human'`` runs one iteration of the viewer loop;
                ``'rgb_array'`` returns the frame as a
                ``(height, width, 3)`` uint8 array.  Any other value is
                ignored and ``None`` is returned.
            close: When true, finish and discard an existing viewer and
                return without rendering.
        """
        if close:
            if self.viewer is not None:
                self._get_viewer().finish()
                self.viewer = None
            return

        if mode == 'rgb_array':
            viewer = self._get_viewer()
            viewer.render()
            data, width, height = viewer.get_image()
            # np.frombuffer replaces the deprecated binary mode of
            # np.fromstring.  It returns a read-only view, so copy after
            # reversing the row order to hand back a writable array, as
            # the fromstring-based code did.
            frame = np.frombuffer(data, dtype='uint8').reshape(
                height, width, 3)
            return frame[::-1, :, :].copy()
        elif mode == 'human':
            self._get_viewer().loop_once()

    def _get_viewer(self):
        """Return the viewer, creating and starting it on first use."""
        if self.viewer is None:
            self.viewer = mujoco_py.MjViewer()
            self.viewer.start()
            self.viewer.set_model(self.model)
            self.viewer_setup()
        return self.viewer

    def _body_index(self, body_name):
        """Return the position of ``body_name`` in ``model.body_names``.

        The name is converted with ``six.b`` before the lookup.  Raises
        ``ValueError`` (from ``index``) when the name is not present.
        """
        return self.model.body_names.index(six.b(body_name))

    def get_body_com(self, body_name):
        """Return the ``com_subtree`` entry for the named body."""
        return self.model.data.com_subtree[self._body_index(body_name)]

    def get_body_comvel(self, body_name):
        """Return the ``body_comvels`` entry for the named body."""
        return self.model.body_comvels[self._body_index(body_name)]

    def get_body_xmat(self, body_name):
        """Return the ``xmat`` entry for the named body as a 3x3 array."""
        return self.model.data.xmat[self._body_index(body_name)].reshape(
            (3, 3))

    def state_vector(self):
        """Return qpos followed by qvel as a single flat array."""
        return np.concatenate([
            self.model.data.qpos.flat,
            self.model.data.qvel.flat,
        ])