mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 17:57:30 +00:00
Fix unpickling Box2D and MuJoCo envs (#3025)
* Try to fix car racing unpickling * Fix EzPickle for BipedalWalker and LunarLander * Shamelessly steal the pickle-unpickle test from Mark, with slight modifications * CarRacing EzPickle fix * Mujoco ezpickle fix
This commit is contained in:
committed by
GitHub
parent
f54319e742
commit
51c2026f19
@@ -169,7 +169,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
}
|
||||
|
||||
def __init__(self, render_mode: Optional[str] = None, hardcore: bool = False):
|
||||
EzPickle.__init__(self)
|
||||
EzPickle.__init__(self, render_mode, hardcore)
|
||||
self.isopen = True
|
||||
|
||||
self.world = Box2D.b2World()
|
||||
|
@@ -200,12 +200,20 @@ class CarRacing(gym.Env, EzPickle):
|
||||
domain_randomize: bool = False,
|
||||
continuous: bool = True,
|
||||
):
|
||||
EzPickle.__init__(self)
|
||||
EzPickle.__init__(
|
||||
self,
|
||||
render_mode,
|
||||
verbose,
|
||||
lap_complete_percent,
|
||||
domain_randomize,
|
||||
continuous,
|
||||
)
|
||||
self.continuous = continuous
|
||||
self.domain_randomize = domain_randomize
|
||||
self.lap_complete_percent = lap_complete_percent
|
||||
self._init_colors()
|
||||
|
||||
self.contactListener_keepref = FrictionDetector(self, lap_complete_percent)
|
||||
self.contactListener_keepref = FrictionDetector(self, self.lap_complete_percent)
|
||||
self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref)
|
||||
self.screen: Optional[pygame.Surface] = None
|
||||
self.surf = None
|
||||
@@ -480,6 +488,10 @@ class CarRacing(gym.Env, EzPickle):
|
||||
):
|
||||
super().reset(seed=seed)
|
||||
self._destroy()
|
||||
self.world.contactListener_bug_workaround = FrictionDetector(
|
||||
self, self.lap_complete_percent
|
||||
)
|
||||
self.world.contactListener = self.world.contactListener_bug_workaround
|
||||
self.reward = 0.0
|
||||
self.prev_reward = 0.0
|
||||
self.tile_visited_count = 0
|
||||
|
@@ -192,7 +192,15 @@ class LunarLander(gym.Env, EzPickle):
|
||||
wind_power: float = 15.0,
|
||||
turbulence_power: float = 1.5,
|
||||
):
|
||||
EzPickle.__init__(self)
|
||||
EzPickle.__init__(
|
||||
self,
|
||||
render_mode,
|
||||
continuous,
|
||||
gravity,
|
||||
enable_wind,
|
||||
wind_power,
|
||||
turbulence_power,
|
||||
)
|
||||
|
||||
assert (
|
||||
-12.0 < gravity and gravity < 0.0
|
||||
|
@@ -24,7 +24,7 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
MuJocoPyEnv.__init__(
|
||||
self, "ant.xml", 5, observation_space=observation_space, **kwargs
|
||||
)
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
|
||||
def step(self, a):
|
||||
xposbefore = self.get_body_com("torso")[0]
|
||||
|
@@ -34,7 +34,19 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
exclude_current_positions_from_observation=True,
|
||||
**kwargs
|
||||
):
|
||||
utils.EzPickle.__init__(**locals())
|
||||
utils.EzPickle.__init__(
|
||||
self,
|
||||
xml_file,
|
||||
ctrl_cost_weight,
|
||||
contact_cost_weight,
|
||||
healthy_reward,
|
||||
terminate_when_unhealthy,
|
||||
healthy_z_range,
|
||||
contact_force_range,
|
||||
reset_noise_scale,
|
||||
exclude_current_positions_from_observation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self._ctrl_cost_weight = ctrl_cost_weight
|
||||
self._contact_cost_weight = contact_cost_weight
|
||||
|
@@ -197,7 +197,20 @@ class AntEnv(MujocoEnv, utils.EzPickle):
|
||||
exclude_current_positions_from_observation=True,
|
||||
**kwargs
|
||||
):
|
||||
utils.EzPickle.__init__(**locals())
|
||||
utils.EzPickle.__init__(
|
||||
self,
|
||||
xml_file,
|
||||
ctrl_cost_weight,
|
||||
use_contact_forces,
|
||||
contact_cost_weight,
|
||||
healthy_reward,
|
||||
terminate_when_unhealthy,
|
||||
healthy_z_range,
|
||||
contact_force_range,
|
||||
reset_noise_scale,
|
||||
exclude_current_positions_from_observation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self._ctrl_cost_weight = ctrl_cost_weight
|
||||
self._contact_cost_weight = contact_cost_weight
|
||||
|
@@ -22,7 +22,7 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
MuJocoPyEnv.__init__(
|
||||
self, "half_cheetah.xml", 5, observation_space=observation_space, **kwargs
|
||||
)
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
|
||||
def step(self, action):
|
||||
xposbefore = self.sim.data.qpos[0]
|
||||
|
@@ -32,7 +32,15 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
exclude_current_positions_from_observation=True,
|
||||
**kwargs
|
||||
):
|
||||
utils.EzPickle.__init__(**locals())
|
||||
utils.EzPickle.__init__(
|
||||
self,
|
||||
xml_file,
|
||||
forward_reward_weight,
|
||||
ctrl_cost_weight,
|
||||
reset_noise_scale,
|
||||
exclude_current_positions_from_observation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self._forward_reward_weight = forward_reward_weight
|
||||
|
||||
|
@@ -151,7 +151,14 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle):
|
||||
exclude_current_positions_from_observation=True,
|
||||
**kwargs
|
||||
):
|
||||
utils.EzPickle.__init__(**locals())
|
||||
utils.EzPickle.__init__(
|
||||
self,
|
||||
forward_reward_weight,
|
||||
ctrl_cost_weight,
|
||||
reset_noise_scale,
|
||||
exclude_current_positions_from_observation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self._forward_reward_weight = forward_reward_weight
|
||||
|
||||
|
@@ -22,7 +22,7 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
MuJocoPyEnv.__init__(
|
||||
self, "hopper.xml", 4, observation_space=observation_space, **kwargs
|
||||
)
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
|
||||
def step(self, a):
|
||||
posbefore = self.sim.data.qpos[0]
|
||||
|
@@ -40,7 +40,20 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
exclude_current_positions_from_observation=True,
|
||||
**kwargs
|
||||
):
|
||||
utils.EzPickle.__init__(**locals())
|
||||
utils.EzPickle.__init__(
|
||||
self,
|
||||
xml_file,
|
||||
forward_reward_weight,
|
||||
ctrl_cost_weight,
|
||||
healthy_reward,
|
||||
terminate_when_unhealthy,
|
||||
healthy_state_range,
|
||||
healthy_z_range,
|
||||
healthy_angle_range,
|
||||
reset_noise_scale,
|
||||
exclude_current_positions_from_observation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self._forward_reward_weight = forward_reward_weight
|
||||
|
||||
|
@@ -162,7 +162,19 @@ class HopperEnv(MujocoEnv, utils.EzPickle):
|
||||
exclude_current_positions_from_observation=True,
|
||||
**kwargs
|
||||
):
|
||||
utils.EzPickle.__init__(**locals())
|
||||
utils.EzPickle.__init__(
|
||||
self,
|
||||
forward_reward_weight,
|
||||
ctrl_cost_weight,
|
||||
healthy_reward,
|
||||
terminate_when_unhealthy,
|
||||
healthy_state_range,
|
||||
healthy_z_range,
|
||||
healthy_angle_range,
|
||||
reset_noise_scale,
|
||||
exclude_current_positions_from_observation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self._forward_reward_weight = forward_reward_weight
|
||||
|
||||
|
@@ -30,7 +30,7 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
MuJocoPyEnv.__init__(
|
||||
self, "humanoid.xml", 5, observation_space=observation_space, **kwargs
|
||||
)
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
|
||||
def _get_obs(self):
|
||||
data = self.sim.data
|
||||
|
@@ -44,7 +44,20 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
exclude_current_positions_from_observation=True,
|
||||
**kwargs
|
||||
):
|
||||
utils.EzPickle.__init__(**locals())
|
||||
utils.EzPickle.__init__(
|
||||
self,
|
||||
xml_file,
|
||||
forward_reward_weight,
|
||||
ctrl_cost_weight,
|
||||
contact_cost_weight,
|
||||
contact_cost_range,
|
||||
healthy_reward,
|
||||
terminate_when_unhealthy,
|
||||
healthy_z_range,
|
||||
reset_noise_scale,
|
||||
exclude_current_positions_from_observation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self._forward_reward_weight = forward_reward_weight
|
||||
self._ctrl_cost_weight = ctrl_cost_weight
|
||||
|
@@ -234,7 +234,17 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):
|
||||
exclude_current_positions_from_observation=True,
|
||||
**kwargs
|
||||
):
|
||||
utils.EzPickle.__init__(**locals())
|
||||
utils.EzPickle.__init__(
|
||||
self,
|
||||
forward_reward_weight,
|
||||
ctrl_cost_weight,
|
||||
healthy_reward,
|
||||
terminate_when_unhealthy,
|
||||
healthy_z_range,
|
||||
reset_noise_scale,
|
||||
exclude_current_positions_from_observation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self._forward_reward_weight = forward_reward_weight
|
||||
self._ctrl_cost_weight = ctrl_cost_weight
|
||||
|
@@ -28,7 +28,7 @@ class HumanoidStandupEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
observation_space=observation_space,
|
||||
**kwargs
|
||||
)
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
|
||||
def _get_obs(self):
|
||||
data = self.sim.data
|
||||
|
@@ -200,7 +200,7 @@ class HumanoidStandupEnv(MujocoEnv, utils.EzPickle):
|
||||
observation_space=observation_space,
|
||||
**kwargs
|
||||
)
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
|
||||
def _get_obs(self):
|
||||
data = self.data
|
||||
|
@@ -26,7 +26,7 @@ class InvertedDoublePendulumEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
observation_space=observation_space,
|
||||
**kwargs
|
||||
)
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
|
||||
def step(self, action):
|
||||
self.do_simulation(action, self.frame_skip)
|
||||
|
@@ -132,7 +132,7 @@ class InvertedDoublePendulumEnv(MujocoEnv, utils.EzPickle):
|
||||
observation_space=observation_space,
|
||||
**kwargs
|
||||
)
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
|
||||
def step(self, action):
|
||||
self.do_simulation(action, self.frame_skip)
|
||||
|
@@ -18,7 +18,7 @@ class InvertedPendulumEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
observation_space = Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64)
|
||||
MuJocoPyEnv.__init__(
|
||||
self,
|
||||
|
@@ -95,7 +95,7 @@ class InvertedPendulumEnv(MujocoEnv, utils.EzPickle):
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
observation_space = Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64)
|
||||
MujocoEnv.__init__(
|
||||
self,
|
||||
|
@@ -18,7 +18,7 @@ class PusherEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
observation_space = Box(low=-np.inf, high=np.inf, shape=(23,), dtype=np.float64)
|
||||
MuJocoPyEnv.__init__(
|
||||
self, "pusher.xml", 5, observation_space=observation_space, **kwargs
|
||||
|
@@ -140,7 +140,7 @@ class PusherEnv(MujocoEnv, utils.EzPickle):
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
observation_space = Box(low=-np.inf, high=np.inf, shape=(23,), dtype=np.float64)
|
||||
MujocoEnv.__init__(
|
||||
self, "pusher.xml", 5, observation_space=observation_space, **kwargs
|
||||
|
@@ -18,7 +18,7 @@ class ReacherEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
observation_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
|
||||
MuJocoPyEnv.__init__(
|
||||
self, "reacher.xml", 2, observation_space=observation_space, **kwargs
|
||||
|
@@ -130,7 +130,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
observation_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
|
||||
MujocoEnv.__init__(
|
||||
self, "reacher.xml", 2, observation_space=observation_space, **kwargs
|
||||
|
@@ -22,7 +22,7 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
MuJocoPyEnv.__init__(
|
||||
self, "swimmer.xml", 4, observation_space=observation_space, **kwargs
|
||||
)
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
|
||||
def step(self, a):
|
||||
ctrl_cost_coeff = 0.0001
|
||||
|
@@ -30,7 +30,15 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
exclude_current_positions_from_observation=True,
|
||||
**kwargs
|
||||
):
|
||||
utils.EzPickle.__init__(**locals())
|
||||
utils.EzPickle.__init__(
|
||||
self,
|
||||
xml_file,
|
||||
forward_reward_weight,
|
||||
ctrl_cost_weight,
|
||||
reset_noise_scale,
|
||||
exclude_current_positions_from_observation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self._forward_reward_weight = forward_reward_weight
|
||||
self._ctrl_cost_weight = ctrl_cost_weight
|
||||
|
@@ -143,7 +143,14 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle):
|
||||
exclude_current_positions_from_observation=True,
|
||||
**kwargs
|
||||
):
|
||||
utils.EzPickle.__init__(**locals())
|
||||
utils.EzPickle.__init__(
|
||||
self,
|
||||
forward_reward_weight,
|
||||
ctrl_cost_weight,
|
||||
reset_noise_scale,
|
||||
exclude_current_positions_from_observation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self._forward_reward_weight = forward_reward_weight
|
||||
self._ctrl_cost_weight = ctrl_cost_weight
|
||||
|
@@ -22,7 +22,7 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
MuJocoPyEnv.__init__(
|
||||
self, "walker2d.xml", 4, observation_space=observation_space, **kwargs
|
||||
)
|
||||
utils.EzPickle.__init__(self)
|
||||
utils.EzPickle.__init__(self, **kwargs)
|
||||
|
||||
def step(self, a):
|
||||
posbefore = self.sim.data.qpos[0]
|
||||
|
@@ -37,7 +37,19 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
exclude_current_positions_from_observation=True,
|
||||
**kwargs
|
||||
):
|
||||
utils.EzPickle.__init__(**locals())
|
||||
utils.EzPickle.__init__(
|
||||
self,
|
||||
xml_file,
|
||||
forward_reward_weight,
|
||||
ctrl_cost_weight,
|
||||
healthy_reward,
|
||||
terminate_when_unhealthy,
|
||||
healthy_z_range,
|
||||
healthy_angle_range,
|
||||
reset_noise_scale,
|
||||
exclude_current_positions_from_observation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self._forward_reward_weight = forward_reward_weight
|
||||
self._ctrl_cost_weight = ctrl_cost_weight
|
||||
|
@@ -166,7 +166,18 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):
|
||||
exclude_current_positions_from_observation=True,
|
||||
**kwargs
|
||||
):
|
||||
utils.EzPickle.__init__(**locals())
|
||||
utils.EzPickle.__init__(
|
||||
self,
|
||||
forward_reward_weight,
|
||||
ctrl_cost_weight,
|
||||
healthy_reward,
|
||||
terminate_when_unhealthy,
|
||||
healthy_z_range,
|
||||
healthy_angle_range,
|
||||
reset_noise_scale,
|
||||
exclude_current_positions_from_observation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self._forward_reward_weight = forward_reward_weight
|
||||
self._ctrl_cost_weight = ctrl_cost_weight
|
||||
|
@@ -1,9 +1,16 @@
|
||||
import pickle
|
||||
|
||||
import pytest
|
||||
|
||||
import gym
|
||||
from gym.envs.registration import EnvSpec
|
||||
from gym.utils.env_checker import check_env
|
||||
from tests.envs.utils import all_testing_env_specs, assert_equals, gym_testing_env_specs
|
||||
from gym.utils.env_checker import check_env, data_equivalence
|
||||
from tests.envs.utils import (
|
||||
all_testing_env_specs,
|
||||
all_testing_initialised_envs,
|
||||
assert_equals,
|
||||
gym_testing_env_specs,
|
||||
)
|
||||
|
||||
# This runs a smoketest on each official registered env. We may want
|
||||
# to try also running environments which are not officially registered envs.
|
||||
@@ -120,3 +127,19 @@ def test_render_modes(spec):
|
||||
new_env.render()
|
||||
new_env.close()
|
||||
env.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env",
|
||||
all_testing_initialised_envs,
|
||||
ids=[env.spec.id for env in all_testing_initialised_envs],
|
||||
)
|
||||
def test_pickle_env(env: gym.Env):
|
||||
pickled_env = pickle.loads(pickle.dumps(env))
|
||||
|
||||
data_equivalence(env.reset(), pickled_env.reset())
|
||||
|
||||
action = env.action_space.sample()
|
||||
data_equivalence(env.step(action), pickled_env.step(action))
|
||||
env.close()
|
||||
pickled_env.close()
|
||||
|
Reference in New Issue
Block a user