Files
Gymnasium/gym/envs/mujoco/mujoco_env.py

137 lines
4.2 KiB
Python
Raw Normal View History

2016-04-30 22:47:51 -07:00
import os
2016-04-27 08:00:58 -07:00
2016-04-30 22:47:51 -07:00
from gym import error, spaces
2016-04-27 08:00:58 -07:00
import numpy as np
2016-04-30 22:47:51 -07:00
from os import path
2016-04-27 08:00:58 -07:00
import gym
import six
2016-04-27 08:00:58 -07:00
try:
    import mujoco_py
    from mujoco_py.mjlib import mjlib
except ImportError as exc:
    # Re-raise as gym's dependency error so users get install guidance,
    # with the original import failure text embedded in the message.
    raise error.DependencyNotInstalled(
        "{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format(exc)
    )
2016-04-27 08:00:58 -07:00
class MujocoEnv(gym.Env):
    """
    Superclass of MuJoCo environments.

    Subclasses must implement ``_step`` (called once from ``__init__`` to
    probe the observation size) and ``reset_model``; they may override
    ``viewer_setup``.
    """

    def __init__(self, model_path, frame_skip):
        """
        Load the MuJoCo model and derive the action/observation spaces.

        :param model_path: path to the model file. A path starting with "/"
            is used as-is; otherwise it is resolved against the ``assets``
            directory next to this module.
        :param frame_skip: multiplier applied to ``model.opt.timestep`` to
            obtain ``self.dt`` (presumably also the ``n_frames`` that
            subclasses pass to ``do_simulation`` -- confirm in subclasses).
        :raises IOError: if the resolved model file does not exist.
        """
        if model_path.startswith("/"):
            fullpath = model_path
        else:
            fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path)
        if not os.path.exists(fullpath):
            raise IOError("File %s does not exist" % fullpath)
        self.frame_skip = frame_skip
        self.model = mujoco_py.MjModel(fullpath)
        self.data = self.model.data
        self.viewer = None

        # Must come after frame_skip and model are set: self.dt reads both.
        self.metadata = {
            'render.modes': ['human', 'rgb_array'],
            'video.frames_per_second': int(np.round(1.0 / self.dt))
        }

        # Snapshot the model's initial joint positions and velocities.
        self.init_qpos = self.model.data.qpos.ravel().copy()
        self.init_qvel = self.model.data.qvel.ravel().copy()

        # Take one zero-action step to discover the observation size.
        # NOTE(review): whether this advances the simulation depends on the
        # subclass's _step; it runs before any reset.
        observation, _reward, done, _info = self._step(np.zeros(self.model.nu))
        assert not done
        self.obs_dim = observation.size

        # Action bounds come from the actuator control ranges: column 0 is
        # the lower bound, column 1 the upper bound.
        bounds = self.model.actuator_ctrlrange.copy()
        low = bounds[:, 0]
        high = bounds[:, 1]
        self.action_space = spaces.Box(low, high)

        # Observations are unbounded in every dimension.
        high = np.inf * np.ones(self.obs_dim)
        low = -high
        self.observation_space = spaces.Box(low, high)

    # methods to override:
    # ----------------------------

    def reset_model(self):
        """
        Reset the robot degrees of freedom (qpos and qvel).
        Implement this in each subclass.

        Called by ``_reset`` after the MuJoCo data has been reset; whatever it
        returns is returned from ``_reset`` (presumably the initial
        observation).
        """
        raise NotImplementedError

    def viewer_setup(self):
        """
        Hook called when the viewer is created and after every reset while a
        viewer exists. Override it to adjust the camera position and similar
        viewer settings; the default does nothing.
        """
        pass

    # -----------------------------

    def _reset(self):
        """
        Reset the simulation and return the result of ``reset_model``.

        Clears the MuJoCo data with ``mj_resetData`` and delegates to
        ``reset_model``. If a viewer is open, it is autoscaled and
        ``viewer_setup`` is re-applied.
        """
        mjlib.mj_resetData(self.model.ptr, self.data.ptr)
        ob = self.reset_model()
        if self.viewer is not None:
            self.viewer.autoscale()
            self.viewer_setup()
        return ob

    def set_state(self, qpos, qvel):
        """
        Overwrite joint positions/velocities and recompute the model state.

        :param qpos: array of shape ``(model.nq,)``.
        :param qvel: array of shape ``(model.nv,)``.
        :raises AssertionError: if either shape is wrong (check is skipped
            under ``python -O``).
        """
        assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
        self.model.data.qpos = qpos
        self.model.data.qvel = qvel
        # Private mujoco_py helper; presumably refreshes subtree quantities
        # such as com_subtree (read by get_body_com) -- confirm in mujoco_py.
        self.model._compute_subtree()  # pylint: disable=W0212
        self.model.forward()

    @property
    def dt(self):
        """Duration of one environment step: simulator timestep times frame_skip."""
        return self.model.opt.timestep * self.frame_skip

    def do_simulation(self, ctrl, n_frames):
        """Write ``ctrl`` to the actuators, then call ``model.step()`` ``n_frames`` times."""
        self.model.data.ctrl = ctrl
        for _ in range(n_frames):
            self.model.step()

    def _render(self, mode='human', close=False):
        """
        Render the environment.

        :param mode: ``'rgb_array'`` returns an image; ``'human'`` updates
            the interactive viewer. Any other mode does nothing.
        :param close: if True, shut down an open viewer and return None.
        :returns: for ``'rgb_array'``, a writable ``uint8`` array of shape
            ``(height, width, 3)``; otherwise None.
        """
        if close:
            if self.viewer is not None:
                self.viewer.finish()
                self.viewer = None
            return

        if mode == 'rgb_array':
            viewer = self._get_viewer()
            viewer.render()
            data, width, height = viewer.get_image()
            # np.fromstring is deprecated for binary input. np.frombuffer
            # gives a read-only view of the bytes, so the final copy keeps the
            # result writable (as fromstring's was) and makes it contiguous.
            image = np.frombuffer(data, dtype=np.uint8).reshape(height, width, 3)
            # Flip rows; presumably the buffer is stored bottom-to-top, as
            # OpenGL pixel read-back is.
            return image[::-1, :, :].copy()
        elif mode == 'human':
            self._get_viewer().loop_once()

    def _get_viewer(self):
        """Return the viewer, creating, starting and configuring it on first use."""
        if self.viewer is None:
            self.viewer = mujoco_py.MjViewer()
            self.viewer.start()
            self.viewer.set_model(self.model)
            self.viewer_setup()
        return self.viewer

    def _body_index(self, body_name):
        """
        Return the position of ``body_name`` in ``model.body_names``.

        The name is converted with ``six.b`` before lookup, to match how
        names are compared in the model's list.

        :raises ValueError: if no body has that name.
        """
        return self.model.body_names.index(six.b(body_name))

    def get_body_com(self, body_name):
        """Return the ``model.data.com_subtree`` row for the named body."""
        return self.model.data.com_subtree[self._body_index(body_name)]

    def get_body_comvel(self, body_name):
        """Return the ``model.body_comvels`` entry for the named body."""
        return self.model.body_comvels[self._body_index(body_name)]

    def get_body_xmat(self, body_name):
        """Return the named body's ``model.data.xmat`` row reshaped to 3x3."""
        return self.model.data.xmat[self._body_index(body_name)].reshape((3, 3))

    def state_vector(self):
        """Return qpos and qvel flattened and concatenated into one 1-D array."""
        return np.concatenate([
            self.model.data.qpos.flat,
            self.model.data.qvel.flat
        ])