Fix unpickling Box2D and MuJoCo envs (#3025)

* Try to fix car racing unpickling

* Fix EzPickle for BipedalWalker and LunarLander

* Shamelessly steal the pickle-unpickle test from Mark, with slight modifications

* CarRacing EzPickle fix

* Mujoco ezpickle fix
This commit is contained in:
Ariel Kwiatkowski
2022-08-16 18:05:36 +02:00
committed by GitHub
parent f54319e742
commit 51c2026f19
32 changed files with 203 additions and 34 deletions

View File

@@ -169,7 +169,7 @@ class BipedalWalker(gym.Env, EzPickle):
}
def __init__(self, render_mode: Optional[str] = None, hardcore: bool = False):
EzPickle.__init__(self)
EzPickle.__init__(self, render_mode, hardcore)
self.isopen = True
self.world = Box2D.b2World()

View File

@@ -200,12 +200,20 @@ class CarRacing(gym.Env, EzPickle):
domain_randomize: bool = False,
continuous: bool = True,
):
EzPickle.__init__(self)
EzPickle.__init__(
self,
render_mode,
verbose,
lap_complete_percent,
domain_randomize,
continuous,
)
self.continuous = continuous
self.domain_randomize = domain_randomize
self.lap_complete_percent = lap_complete_percent
self._init_colors()
self.contactListener_keepref = FrictionDetector(self, lap_complete_percent)
self.contactListener_keepref = FrictionDetector(self, self.lap_complete_percent)
self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref)
self.screen: Optional[pygame.Surface] = None
self.surf = None
@@ -480,6 +488,10 @@ class CarRacing(gym.Env, EzPickle):
):
super().reset(seed=seed)
self._destroy()
self.world.contactListener_bug_workaround = FrictionDetector(
self, self.lap_complete_percent
)
self.world.contactListener = self.world.contactListener_bug_workaround
self.reward = 0.0
self.prev_reward = 0.0
self.tile_visited_count = 0

View File

@@ -192,7 +192,15 @@ class LunarLander(gym.Env, EzPickle):
wind_power: float = 15.0,
turbulence_power: float = 1.5,
):
EzPickle.__init__(self)
EzPickle.__init__(
self,
render_mode,
continuous,
gravity,
enable_wind,
wind_power,
turbulence_power,
)
assert (
-12.0 < gravity and gravity < 0.0

View File

@@ -24,7 +24,7 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
MuJocoPyEnv.__init__(
self, "ant.xml", 5, observation_space=observation_space, **kwargs
)
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
def step(self, a):
xposbefore = self.get_body_com("torso")[0]

View File

@@ -34,7 +34,19 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
exclude_current_positions_from_observation=True,
**kwargs
):
utils.EzPickle.__init__(**locals())
utils.EzPickle.__init__(
self,
xml_file,
ctrl_cost_weight,
contact_cost_weight,
healthy_reward,
terminate_when_unhealthy,
healthy_z_range,
contact_force_range,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs
)
self._ctrl_cost_weight = ctrl_cost_weight
self._contact_cost_weight = contact_cost_weight

View File

@@ -197,7 +197,20 @@ class AntEnv(MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation=True,
**kwargs
):
utils.EzPickle.__init__(**locals())
utils.EzPickle.__init__(
self,
xml_file,
ctrl_cost_weight,
use_contact_forces,
contact_cost_weight,
healthy_reward,
terminate_when_unhealthy,
healthy_z_range,
contact_force_range,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs
)
self._ctrl_cost_weight = ctrl_cost_weight
self._contact_cost_weight = contact_cost_weight

View File

@@ -22,7 +22,7 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
MuJocoPyEnv.__init__(
self, "half_cheetah.xml", 5, observation_space=observation_space, **kwargs
)
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
def step(self, action):
xposbefore = self.sim.data.qpos[0]

View File

@@ -32,7 +32,15 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
exclude_current_positions_from_observation=True,
**kwargs
):
utils.EzPickle.__init__(**locals())
utils.EzPickle.__init__(
self,
xml_file,
forward_reward_weight,
ctrl_cost_weight,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs
)
self._forward_reward_weight = forward_reward_weight

View File

@@ -151,7 +151,14 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation=True,
**kwargs
):
utils.EzPickle.__init__(**locals())
utils.EzPickle.__init__(
self,
forward_reward_weight,
ctrl_cost_weight,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs
)
self._forward_reward_weight = forward_reward_weight

View File

@@ -22,7 +22,7 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
MuJocoPyEnv.__init__(
self, "hopper.xml", 4, observation_space=observation_space, **kwargs
)
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
def step(self, a):
posbefore = self.sim.data.qpos[0]

View File

@@ -40,7 +40,20 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
exclude_current_positions_from_observation=True,
**kwargs
):
utils.EzPickle.__init__(**locals())
utils.EzPickle.__init__(
self,
xml_file,
forward_reward_weight,
ctrl_cost_weight,
healthy_reward,
terminate_when_unhealthy,
healthy_state_range,
healthy_z_range,
healthy_angle_range,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs
)
self._forward_reward_weight = forward_reward_weight

View File

@@ -162,7 +162,19 @@ class HopperEnv(MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation=True,
**kwargs
):
utils.EzPickle.__init__(**locals())
utils.EzPickle.__init__(
self,
forward_reward_weight,
ctrl_cost_weight,
healthy_reward,
terminate_when_unhealthy,
healthy_state_range,
healthy_z_range,
healthy_angle_range,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs
)
self._forward_reward_weight = forward_reward_weight

View File

@@ -30,7 +30,7 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
MuJocoPyEnv.__init__(
self, "humanoid.xml", 5, observation_space=observation_space, **kwargs
)
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
def _get_obs(self):
data = self.sim.data

View File

@@ -44,7 +44,20 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
exclude_current_positions_from_observation=True,
**kwargs
):
utils.EzPickle.__init__(**locals())
utils.EzPickle.__init__(
self,
xml_file,
forward_reward_weight,
ctrl_cost_weight,
contact_cost_weight,
contact_cost_range,
healthy_reward,
terminate_when_unhealthy,
healthy_z_range,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs
)
self._forward_reward_weight = forward_reward_weight
self._ctrl_cost_weight = ctrl_cost_weight

View File

@@ -234,7 +234,17 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation=True,
**kwargs
):
utils.EzPickle.__init__(**locals())
utils.EzPickle.__init__(
self,
forward_reward_weight,
ctrl_cost_weight,
healthy_reward,
terminate_when_unhealthy,
healthy_z_range,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs
)
self._forward_reward_weight = forward_reward_weight
self._ctrl_cost_weight = ctrl_cost_weight

View File

@@ -28,7 +28,7 @@ class HumanoidStandupEnv(MuJocoPyEnv, utils.EzPickle):
observation_space=observation_space,
**kwargs
)
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
def _get_obs(self):
data = self.sim.data

View File

@@ -200,7 +200,7 @@ class HumanoidStandupEnv(MujocoEnv, utils.EzPickle):
observation_space=observation_space,
**kwargs
)
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
def _get_obs(self):
data = self.data

View File

@@ -26,7 +26,7 @@ class InvertedDoublePendulumEnv(MuJocoPyEnv, utils.EzPickle):
observation_space=observation_space,
**kwargs
)
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
def step(self, action):
self.do_simulation(action, self.frame_skip)

View File

@@ -132,7 +132,7 @@ class InvertedDoublePendulumEnv(MujocoEnv, utils.EzPickle):
observation_space=observation_space,
**kwargs
)
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
def step(self, action):
self.do_simulation(action, self.frame_skip)

View File

@@ -18,7 +18,7 @@ class InvertedPendulumEnv(MuJocoPyEnv, utils.EzPickle):
}
def __init__(self, **kwargs):
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
observation_space = Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64)
MuJocoPyEnv.__init__(
self,

View File

@@ -95,7 +95,7 @@ class InvertedPendulumEnv(MujocoEnv, utils.EzPickle):
}
def __init__(self, **kwargs):
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
observation_space = Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64)
MujocoEnv.__init__(
self,

View File

@@ -18,7 +18,7 @@ class PusherEnv(MuJocoPyEnv, utils.EzPickle):
}
def __init__(self, **kwargs):
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
observation_space = Box(low=-np.inf, high=np.inf, shape=(23,), dtype=np.float64)
MuJocoPyEnv.__init__(
self, "pusher.xml", 5, observation_space=observation_space, **kwargs

View File

@@ -140,7 +140,7 @@ class PusherEnv(MujocoEnv, utils.EzPickle):
}
def __init__(self, **kwargs):
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
observation_space = Box(low=-np.inf, high=np.inf, shape=(23,), dtype=np.float64)
MujocoEnv.__init__(
self, "pusher.xml", 5, observation_space=observation_space, **kwargs

View File

@@ -18,7 +18,7 @@ class ReacherEnv(MuJocoPyEnv, utils.EzPickle):
}
def __init__(self, **kwargs):
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
observation_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
MuJocoPyEnv.__init__(
self, "reacher.xml", 2, observation_space=observation_space, **kwargs

View File

@@ -130,7 +130,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
}
def __init__(self, **kwargs):
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
observation_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
MujocoEnv.__init__(
self, "reacher.xml", 2, observation_space=observation_space, **kwargs

View File

@@ -22,7 +22,7 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
MuJocoPyEnv.__init__(
self, "swimmer.xml", 4, observation_space=observation_space, **kwargs
)
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
def step(self, a):
ctrl_cost_coeff = 0.0001

View File

@@ -30,7 +30,15 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
exclude_current_positions_from_observation=True,
**kwargs
):
utils.EzPickle.__init__(**locals())
utils.EzPickle.__init__(
self,
xml_file,
forward_reward_weight,
ctrl_cost_weight,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs
)
self._forward_reward_weight = forward_reward_weight
self._ctrl_cost_weight = ctrl_cost_weight

View File

@@ -143,7 +143,14 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation=True,
**kwargs
):
utils.EzPickle.__init__(**locals())
utils.EzPickle.__init__(
self,
forward_reward_weight,
ctrl_cost_weight,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs
)
self._forward_reward_weight = forward_reward_weight
self._ctrl_cost_weight = ctrl_cost_weight

View File

@@ -22,7 +22,7 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
MuJocoPyEnv.__init__(
self, "walker2d.xml", 4, observation_space=observation_space, **kwargs
)
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(self, **kwargs)
def step(self, a):
posbefore = self.sim.data.qpos[0]

View File

@@ -37,7 +37,19 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
exclude_current_positions_from_observation=True,
**kwargs
):
utils.EzPickle.__init__(**locals())
utils.EzPickle.__init__(
self,
xml_file,
forward_reward_weight,
ctrl_cost_weight,
healthy_reward,
terminate_when_unhealthy,
healthy_z_range,
healthy_angle_range,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs
)
self._forward_reward_weight = forward_reward_weight
self._ctrl_cost_weight = ctrl_cost_weight

View File

@@ -166,7 +166,18 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation=True,
**kwargs
):
utils.EzPickle.__init__(**locals())
utils.EzPickle.__init__(
self,
forward_reward_weight,
ctrl_cost_weight,
healthy_reward,
terminate_when_unhealthy,
healthy_z_range,
healthy_angle_range,
reset_noise_scale,
exclude_current_positions_from_observation,
**kwargs
)
self._forward_reward_weight = forward_reward_weight
self._ctrl_cost_weight = ctrl_cost_weight

View File

@@ -1,9 +1,16 @@
import pickle
import pytest
import gym
from gym.envs.registration import EnvSpec
from gym.utils.env_checker import check_env
from tests.envs.utils import all_testing_env_specs, assert_equals, gym_testing_env_specs
from gym.utils.env_checker import check_env, data_equivalence
from tests.envs.utils import (
all_testing_env_specs,
all_testing_initialised_envs,
assert_equals,
gym_testing_env_specs,
)
# This runs a smoketest on each official registered env. We may want
# to try also running environments which are not officially registered envs.
@@ -120,3 +127,19 @@ def test_render_modes(spec):
new_env.render()
new_env.close()
env.close()
@pytest.mark.parametrize(
    "env",
    all_testing_initialised_envs,
    ids=[env.spec.id for env in all_testing_initialised_envs],
)
def test_pickle_env(env: gym.Env):
    """Check that an env survives a pickle round-trip.

    The unpickled copy must produce equivalent ``reset`` and ``step`` results
    to the original (the pickled env carries the same RNG state, so unseeded
    resets stay in lockstep).
    """
    pickled_env = pickle.loads(pickle.dumps(env))
    # data_equivalence returns a bool; without the assert the comparison is
    # silently discarded and the test can never fail.
    assert data_equivalence(env.reset(), pickled_env.reset())
    # Apply the same sampled action to both envs and compare the transitions.
    action = env.action_space.sample()
    assert data_equivalence(env.step(action), pickled_env.step(action))
    env.close()
    pickled_env.close()