Initialize observation spaces and pytest (#2929)

* Remove step initialization for mujoco obs spaces

	* remove step initialization for mujoco obs space

	* pre-commit

pytest obs space mujoco
This commit is contained in:
Rodrigo de Lazcano
2022-06-30 10:59:59 -04:00
committed by GitHub
parent 7f6effbc0d
commit 61a39f41bc
30 changed files with 364 additions and 37 deletions

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@@ -17,8 +18,16 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
observation_space = Box(
low=-np.inf, high=np.inf, shape=(111,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "ant.xml", 5, mujoco_bindings="mujoco_py", **kwargs self,
"ant.xml",
5,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
DEFAULT_CAMERA_CONFIG = { DEFAULT_CAMERA_CONFIG = {
"distance": 4.0, "distance": 4.0,
@@ -50,8 +51,22 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(111,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(113,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, xml_file, 5, mujoco_bindings="mujoco_py", **kwargs self,
xml_file,
5,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
@property @property

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
DEFAULT_CAMERA_CONFIG = { DEFAULT_CAMERA_CONFIG = {
"distance": 4.0, "distance": 4.0,
@@ -215,7 +216,19 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, xml_file, 5, **kwargs) obs_shape = 27
if not exclude_current_positions_from_observation:
obs_shape += 2
if use_contact_forces:
obs_shape += 84
observation_space = Box(
low=-np.inf, high=np.inf, shape=(obs_shape,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__(
self, xml_file, 5, observation_space=observation_space, **kwargs
)
@property @property
def healthy_reward(self): def healthy_reward(self):

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle): class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@@ -17,8 +18,14 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
observation_space = Box(low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "half_cheetah.xml", 5, mujoco_bindings="mujoco_py", **kwargs self,
"half_cheetah.xml",
5,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)

View File

@@ -4,6 +4,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
DEFAULT_CAMERA_CONFIG = { DEFAULT_CAMERA_CONFIG = {
"distance": 4.0, "distance": 4.0,
@@ -43,8 +44,22 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, xml_file, 5, mujoco_bindings="mujoco_py", **kwargs self,
xml_file,
5,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
def control_cost(self, action): def control_cost(self, action):

View File

@@ -4,6 +4,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
DEFAULT_CAMERA_CONFIG = { DEFAULT_CAMERA_CONFIG = {
"distance": 4.0, "distance": 4.0,
@@ -162,7 +163,18 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, "half_cheetah.xml", 5, **kwargs) if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__(
self, "half_cheetah.xml", 5, observation_space=observation_space, **kwargs
)
def control_cost(self, action): def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action)) control_cost = self._ctrl_cost_weight * np.sum(np.square(action))

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle): class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@@ -17,8 +18,14 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
observation_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "hopper.xml", 4, mujoco_bindings="mujoco_py", **kwargs self,
"hopper.xml",
4,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)

View File

@@ -4,6 +4,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
DEFAULT_CAMERA_CONFIG = { DEFAULT_CAMERA_CONFIG = {
"trackbodyid": 2, "trackbodyid": 2,
@@ -58,8 +59,22 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(12,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, xml_file, 4, mujoco_bindings="mujoco_py", **kwargs self,
xml_file,
4,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
@property @property

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
DEFAULT_CAMERA_CONFIG = { DEFAULT_CAMERA_CONFIG = {
"trackbodyid": 2, "trackbodyid": 2,
@@ -180,7 +181,18 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, "hopper.xml", 4, **kwargs) if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(12,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__(
self, "hopper.xml", 4, observation_space=observation_space, **kwargs
)
@property @property
def healthy_reward(self): def healthy_reward(self):

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
def mass_center(model, sim): def mass_center(model, sim):
@@ -23,8 +24,16 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
observation_space = Box(
low=-np.inf, high=np.inf, shape=(376,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "humanoid.xml", 5, mujoco_bindings="mujoco_py", **kwargs self,
"humanoid.xml",
5,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
DEFAULT_CAMERA_CONFIG = { DEFAULT_CAMERA_CONFIG = {
"trackbodyid": 1, "trackbodyid": 1,
@@ -58,9 +59,22 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._exclude_current_positions_from_observation = ( self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(376,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(378,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, xml_file, 5, mujoco_bindings="mujoco_py", **kwargs self,
xml_file,
5,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
@property @property

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
DEFAULT_CAMERA_CONFIG = { DEFAULT_CAMERA_CONFIG = {
"trackbodyid": 1, "trackbodyid": 1,
@@ -248,7 +249,18 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, "humanoid.xml", 5, **kwargs) if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(376,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(378,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__(
self, "humanoid.xml", 5, observation_space=observation_space, **kwargs
)
@property @property
def healthy_reward(self): def healthy_reward(self):

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle): class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@@ -17,8 +18,16 @@ class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
observation_space = Box(
low=-np.inf, high=np.inf, shape=(376,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "humanoidstandup.xml", 5, mujoco_bindings="mujoco_py", **kwargs self,
"humanoidstandup.xml",
5,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle): class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@@ -189,7 +190,16 @@ class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
mujoco_env.MujocoEnv.__init__(self, "humanoidstandup.xml", 5, **kwargs) observation_space = Box(
low=-np.inf, high=np.inf, shape=(376,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__(
self,
"humanoidstandup.xml",
5,
observation_space=observation_space,
**kwargs
)
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
def _get_obs(self): def _get_obs(self):

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle): class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@@ -17,11 +18,13 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
observation_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, self,
"inverted_double_pendulum.xml", "inverted_double_pendulum.xml",
5, 5,
mujoco_bindings="mujoco_py", mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs **kwargs
) )
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle): class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@@ -123,7 +124,14 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
mujoco_env.MujocoEnv.__init__(self, "inverted_double_pendulum.xml", 5, **kwargs) observation_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
mujoco_env.MujocoEnv.__init__(
self,
"inverted_double_pendulum.xml",
5,
observation_space=observation_space,
**kwargs
)
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
def step(self, action): def step(self, action):

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle): class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@@ -18,8 +19,14 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self, **kwargs): def __init__(self, **kwargs):
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
observation_space = Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "inverted_pendulum.xml", 2, mujoco_bindings="mujoco_py", **kwargs self,
"inverted_pendulum.xml",
2,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
def step(self, a): def step(self, a):

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle): class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@@ -95,7 +96,14 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self, **kwargs): def __init__(self, **kwargs):
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
mujoco_env.MujocoEnv.__init__(self, "inverted_pendulum.xml", 2, **kwargs) observation_space = Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64)
mujoco_env.MujocoEnv.__init__(
self,
"inverted_pendulum.xml",
2,
observation_space=observation_space,
**kwargs
)
def step(self, a): def step(self, a):
reward = 1.0 reward = 1.0

View File

@@ -7,6 +7,7 @@ import numpy as np
import gym import gym
from gym import error, logger, spaces from gym import error, logger, spaces
from gym.spaces import Space
from gym.utils.renderer import Renderer from gym.utils.renderer import Renderer
DEFAULT_SIZE = 480 DEFAULT_SIZE = 480
@@ -39,6 +40,7 @@ class MujocoEnv(gym.Env):
self, self,
model_path, model_path,
frame_skip, frame_skip,
observation_space: Space,
render_mode: Optional[str] = None, render_mode: Optional[str] = None,
width: int = DEFAULT_SIZE, width: int = DEFAULT_SIZE,
height: int = DEFAULT_SIZE, height: int = DEFAULT_SIZE,
@@ -120,11 +122,7 @@ class MujocoEnv(gym.Env):
) )
self.renderer = Renderer(self.render_mode, render_frame) self.renderer = Renderer(self.render_mode, render_frame)
action = self.action_space.sample() self.observation_space = observation_space
observation, _reward, done, _info = self.step(action)
assert not done
self._set_observation_space(observation)
def _set_action_space(self): def _set_action_space(self):
bounds = self.model.actuator_ctrlrange.copy().astype(np.float32) bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)
@@ -132,10 +130,6 @@ class MujocoEnv(gym.Env):
self.action_space = spaces.Box(low=low, high=high, dtype=np.float32) self.action_space = spaces.Box(low=low, high=high, dtype=np.float32)
return self.action_space return self.action_space
def _set_observation_space(self, observation):
self.observation_space = convert_observation_to_space(observation)
return self.observation_space
# methods to override: # methods to override:
# ---------------------------- # ----------------------------

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle): class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@@ -18,8 +19,14 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self, **kwargs): def __init__(self, **kwargs):
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
observation_space = Box(low=-np.inf, high=np.inf, shape=(23,), dtype=np.float64)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "pusher.xml", 5, mujoco_bindings="mujoco_py", **kwargs self,
"pusher.xml",
5,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
def step(self, a): def step(self, a):

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle): class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@@ -140,7 +141,10 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self, **kwargs): def __init__(self, **kwargs):
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
mujoco_env.MujocoEnv.__init__(self, "pusher.xml", 5, **kwargs) observation_space = Box(low=-np.inf, high=np.inf, shape=(23,), dtype=np.float64)
mujoco_env.MujocoEnv.__init__(
self, "pusher.xml", 5, observation_space=observation_space, **kwargs
)
def step(self, a): def step(self, a):
vec_1 = self.get_body_com("object") - self.get_body_com("tips_arm") vec_1 = self.get_body_com("object") - self.get_body_com("tips_arm")

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle): class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@@ -18,8 +19,14 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self, **kwargs): def __init__(self, **kwargs):
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
observation_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "reacher.xml", 2, mujoco_bindings="mujoco_py", **kwargs self,
"reacher.xml",
2,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
def step(self, a): def step(self, a):

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle): class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@@ -130,7 +131,10 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self, **kwargs): def __init__(self, **kwargs):
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
mujoco_env.MujocoEnv.__init__(self, "reacher.xml", 2, **kwargs) observation_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
mujoco_env.MujocoEnv.__init__(
self, "reacher.xml", 2, observation_space=observation_space, **kwargs
)
def step(self, a): def step(self, a):
vec = self.get_body_com("fingertip") - self.get_body_com("target") vec = self.get_body_com("fingertip") - self.get_body_com("target")

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@@ -17,8 +18,14 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
observation_space = Box(low=-np.inf, high=np.inf, shape=(8,), dtype=np.float64)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "swimmer.xml", 4, mujoco_bindings="mujoco_py", **kwargs self,
"swimmer.xml",
4,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)

View File

@@ -4,6 +4,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
DEFAULT_CAMERA_CONFIG = {} DEFAULT_CAMERA_CONFIG = {}
@@ -40,8 +41,22 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(8,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, xml_file, 4, mujoco_bindings="mujoco_py", **kwargs self,
xml_file,
4,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
def control_cost(self, action): def control_cost(self, action):

View File

@@ -4,6 +4,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
DEFAULT_CAMERA_CONFIG = {} DEFAULT_CAMERA_CONFIG = {}
@@ -152,8 +153,17 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self._exclude_current_positions_from_observation = ( self._exclude_current_positions_from_observation = (
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
if exclude_current_positions_from_observation:
mujoco_env.MujocoEnv.__init__(self, "swimmer.xml", 4, **kwargs) observation_space = Box(
low=-np.inf, high=np.inf, shape=(8,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__(
self, "swimmer.xml", 4, observation_space=observation_space, **kwargs
)
def control_cost(self, action): def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action)) control_cost = self._ctrl_cost_weight * np.sum(np.square(action))

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle): class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
@@ -17,8 +18,14 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
observation_space = Box(low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "walker2d.xml", 4, mujoco_bindings="mujoco_py", **kwargs self,
"walker2d.xml",
4,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
DEFAULT_CAMERA_CONFIG = { DEFAULT_CAMERA_CONFIG = {
"trackbodyid": 2, "trackbodyid": 2,
@@ -53,8 +54,22 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, xml_file, 4, mujoco_bindings="mujoco_py", **kwargs self,
xml_file,
4,
mujoco_bindings="mujoco_py",
observation_space=observation_space,
**kwargs
) )
@property @property

View File

@@ -2,6 +2,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import mujoco_env from gym.envs.mujoco import mujoco_env
from gym.spaces import Box
DEFAULT_CAMERA_CONFIG = { DEFAULT_CAMERA_CONFIG = {
"trackbodyid": 2, "trackbodyid": 2,
@@ -182,7 +183,18 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, "walker2d.xml", 4, **kwargs) if exclude_current_positions_from_observation:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
)
else:
observation_space = Box(
low=-np.inf, high=np.inf, shape=(18,), dtype=np.float64
)
mujoco_env.MujocoEnv.__init__(
self, "walker2d.xml", 4, observation_space=observation_space, **kwargs
)
@property @property
def healthy_reward(self): def healthy_reward(self):

View File

@@ -3,6 +3,7 @@ import pytest
import gym import gym
from gym import envs from gym import envs
from gym.envs.registration import EnvSpec
from tests.envs.utils import mujoco_testing_env_specs from tests.envs.utils import mujoco_testing_env_specs
EPS = 1e-6 EPS = 1e-6
@@ -37,6 +38,65 @@ def verify_environments_match(
break break
EXCLUDE_POS_FROM_OBS = [
"Ant",
"HalfCheetah",
"Hopper",
"Humanoid",
"Swimmer",
"Walker2d",
]
@pytest.mark.parametrize(
"env_spec",
mujoco_testing_env_specs,
ids=[env_spec.id for env_spec in mujoco_testing_env_specs],
)
def test_obs_space_mujoco_environments(env_spec: EnvSpec):
"""Check that the returned observations are contained in the observation space of the environment"""
env = env_spec.make(disable_env_checker=True)
reset_obs = env.reset()
assert env.observation_space.contains(
reset_obs
), f"Obseravtion returned by reset() of {env_spec.id} is not contained in the default observation space {env.observation_space}."
action = env.action_space.sample()
step_obs, _, _, _ = env.step(action)
assert env.observation_space.contains(
step_obs
), f"Obseravtion returned by step(action) of {env_spec.id} is not contained in the default observation space {env.observation_space}."
if env_spec.name in EXCLUDE_POS_FROM_OBS and (
env_spec.version == 4 or env_spec.version == 3
):
env = env_spec.make(
disable_env_checker=True, exclude_current_positions_from_observation=False
)
reset_obs = env.reset()
assert env.observation_space.contains(
reset_obs
), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when excluding current position from observation."
step_obs, _, _, _ = env.step(action)
assert env.observation_space.contains(
step_obs
), f"Obseravtion returned by step(action) of {env_spec.id} is not contained in the default observation space {env.observation_space} when excluding current position from observation."
# Ant-v4 has the option of including contact forces in the observation space with the use_contact_forces argument
if env_spec.name == "Ant" and env_spec.version == 4:
env = env_spec.make(disable_env_checker=True, use_contact_forces=True)
reset_obs = env.reset()
assert env.observation_space.contains(
reset_obs
), f"Obseravtion of {env_spec.id} is not contained in the default observation space {env.observation_space} when using contact forces."
step_obs, _, _, _ = env.step(action)
assert env.observation_space.contains(
step_obs
), f"Obseravtion returned by step(action) of {env_spec.id} is not contained in the default observation space {env.observation_space} when using contact forces."
MUJOCO_V2_V3_ENVS = [ MUJOCO_V2_V3_ENVS = [
spec.name spec.name
for spec in mujoco_testing_env_specs for spec in mujoco_testing_env_specs