From 51c2026f19e23ce497f48846f6989474fac2478e Mon Sep 17 00:00:00 2001 From: Ariel Kwiatkowski Date: Tue, 16 Aug 2022 18:05:36 +0200 Subject: [PATCH] Fix unpickling Box2D and MuJoCo envs (#3025) * Try to fix car racing unpickling * Fix EzPickle for BipedalWalker and LunarLander * Shamelessly steal the pickle-unpickle test from Mark, with slight modifications * CarRacing EzPickle fix * Mujoco ezpickle fix --- gym/envs/box2d/bipedal_walker.py | 2 +- gym/envs/box2d/car_racing.py | 16 +++++++++-- gym/envs/box2d/lunar_lander.py | 10 ++++++- gym/envs/mujoco/ant.py | 2 +- gym/envs/mujoco/ant_v3.py | 14 +++++++++- gym/envs/mujoco/ant_v4.py | 15 ++++++++++- gym/envs/mujoco/half_cheetah.py | 2 +- gym/envs/mujoco/half_cheetah_v3.py | 10 ++++++- gym/envs/mujoco/half_cheetah_v4.py | 9 ++++++- gym/envs/mujoco/hopper.py | 2 +- gym/envs/mujoco/hopper_v3.py | 15 ++++++++++- gym/envs/mujoco/hopper_v4.py | 14 +++++++++- gym/envs/mujoco/humanoid.py | 2 +- gym/envs/mujoco/humanoid_v3.py | 15 ++++++++++- gym/envs/mujoco/humanoid_v4.py | 12 ++++++++- gym/envs/mujoco/humanoidstandup.py | 2 +- gym/envs/mujoco/humanoidstandup_v4.py | 2 +- gym/envs/mujoco/inverted_double_pendulum.py | 2 +- .../mujoco/inverted_double_pendulum_v4.py | 2 +- gym/envs/mujoco/inverted_pendulum.py | 2 +- gym/envs/mujoco/inverted_pendulum_v4.py | 2 +- gym/envs/mujoco/pusher.py | 2 +- gym/envs/mujoco/pusher_v4.py | 2 +- gym/envs/mujoco/reacher.py | 2 +- gym/envs/mujoco/reacher_v4.py | 2 +- gym/envs/mujoco/swimmer.py | 2 +- gym/envs/mujoco/swimmer_v3.py | 10 ++++++- gym/envs/mujoco/swimmer_v4.py | 9 ++++++- gym/envs/mujoco/walker2d.py | 2 +- gym/envs/mujoco/walker2d_v3.py | 14 +++++++++- gym/envs/mujoco/walker2d_v4.py | 13 ++++++++- tests/envs/test_envs.py | 27 +++++++++++++++++-- 32 files changed, 203 insertions(+), 34 deletions(-) diff --git a/gym/envs/box2d/bipedal_walker.py b/gym/envs/box2d/bipedal_walker.py index 569de4deb..b25e27617 100644 --- a/gym/envs/box2d/bipedal_walker.py +++ b/gym/envs/box2d/bipedal_walker.py @@ -169,7 +169,7 @@ class BipedalWalker(gym.Env, EzPickle): } def __init__(self, render_mode: Optional[str] = None, hardcore: bool = False): - EzPickle.__init__(self) + EzPickle.__init__(self, render_mode, hardcore) self.isopen = True self.world = Box2D.b2World() diff --git a/gym/envs/box2d/car_racing.py b/gym/envs/box2d/car_racing.py index 188b3dd0c..968775c9b 100644 --- a/gym/envs/box2d/car_racing.py +++ b/gym/envs/box2d/car_racing.py @@ -200,12 +200,20 @@ class CarRacing(gym.Env, EzPickle): domain_randomize: bool = False, continuous: bool = True, ): - EzPickle.__init__(self) + EzPickle.__init__( + self, + render_mode, + verbose, + lap_complete_percent, + domain_randomize, + continuous, + ) self.continuous = continuous self.domain_randomize = domain_randomize + self.lap_complete_percent = lap_complete_percent self._init_colors() - self.contactListener_keepref = FrictionDetector(self, lap_complete_percent) + self.contactListener_keepref = FrictionDetector(self, self.lap_complete_percent) self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref) self.screen: Optional[pygame.Surface] = None self.surf = None @@ -480,6 +488,10 @@ class CarRacing(gym.Env, EzPickle): ): super().reset(seed=seed) self._destroy() + self.world.contactListener_bug_workaround = FrictionDetector( + self, self.lap_complete_percent + ) + self.world.contactListener = self.world.contactListener_bug_workaround self.reward = 0.0 self.prev_reward = 0.0 self.tile_visited_count = 0 diff --git a/gym/envs/box2d/lunar_lander.py b/gym/envs/box2d/lunar_lander.py index 711544030..122dda067 100644 --- a/gym/envs/box2d/lunar_lander.py +++ b/gym/envs/box2d/lunar_lander.py @@ -192,7 +192,15 @@ class LunarLander(gym.Env, EzPickle): wind_power: float = 15.0, turbulence_power: float = 1.5, ): - EzPickle.__init__(self) + EzPickle.__init__( + self, + render_mode, + continuous, + gravity, + enable_wind, + wind_power, + turbulence_power, + ) assert ( -12.0 < gravity and gravity < 0.0 diff --git a/gym/envs/mujoco/ant.py b/gym/envs/mujoco/ant.py index 1a981633d..bb7c5875f 100644 --- a/gym/envs/mujoco/ant.py +++ b/gym/envs/mujoco/ant.py @@ -24,7 +24,7 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle): MuJocoPyEnv.__init__( self, "ant.xml", 5, observation_space=observation_space, **kwargs ) - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) def step(self, a): xposbefore = self.get_body_com("torso")[0] diff --git a/gym/envs/mujoco/ant_v3.py b/gym/envs/mujoco/ant_v3.py index 4a6fa6bd3..5902340df 100644 --- a/gym/envs/mujoco/ant_v3.py +++ b/gym/envs/mujoco/ant_v3.py @@ -34,7 +34,19 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle): exclude_current_positions_from_observation=True, **kwargs ): - utils.EzPickle.__init__(**locals()) + utils.EzPickle.__init__( + self, + xml_file, + ctrl_cost_weight, + contact_cost_weight, + healthy_reward, + terminate_when_unhealthy, + healthy_z_range, + contact_force_range, + reset_noise_scale, + exclude_current_positions_from_observation, + **kwargs + ) self._ctrl_cost_weight = ctrl_cost_weight self._contact_cost_weight = contact_cost_weight diff --git a/gym/envs/mujoco/ant_v4.py b/gym/envs/mujoco/ant_v4.py index ffd071e22..45a1ee42d 100644 --- a/gym/envs/mujoco/ant_v4.py +++ b/gym/envs/mujoco/ant_v4.py @@ -197,7 +197,20 @@ class AntEnv(MujocoEnv, utils.EzPickle): exclude_current_positions_from_observation=True, **kwargs ): - utils.EzPickle.__init__(**locals()) + utils.EzPickle.__init__( + self, + xml_file, + ctrl_cost_weight, + use_contact_forces, + contact_cost_weight, + healthy_reward, + terminate_when_unhealthy, + healthy_z_range, + contact_force_range, + reset_noise_scale, + exclude_current_positions_from_observation, + **kwargs + ) self._ctrl_cost_weight = ctrl_cost_weight self._contact_cost_weight = contact_cost_weight diff --git a/gym/envs/mujoco/half_cheetah.py b/gym/envs/mujoco/half_cheetah.py index 069b4d146..5645b54a8 100644 --- a/gym/envs/mujoco/half_cheetah.py +++ b/gym/envs/mujoco/half_cheetah.py @@ -22,7 +22,7 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle): MuJocoPyEnv.__init__( self, "half_cheetah.xml", 5, observation_space=observation_space, **kwargs ) - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) def step(self, action): xposbefore = self.sim.data.qpos[0] diff --git a/gym/envs/mujoco/half_cheetah_v3.py b/gym/envs/mujoco/half_cheetah_v3.py index 07d0d74c8..d57c8f9fa 100644 --- a/gym/envs/mujoco/half_cheetah_v3.py +++ b/gym/envs/mujoco/half_cheetah_v3.py @@ -32,7 +32,15 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle): exclude_current_positions_from_observation=True, **kwargs ): - utils.EzPickle.__init__(**locals()) + utils.EzPickle.__init__( + self, + xml_file, + forward_reward_weight, + ctrl_cost_weight, + reset_noise_scale, + exclude_current_positions_from_observation, + **kwargs + ) self._forward_reward_weight = forward_reward_weight diff --git a/gym/envs/mujoco/half_cheetah_v4.py b/gym/envs/mujoco/half_cheetah_v4.py index 7ecf6de62..9a1e1fc91 100644 --- a/gym/envs/mujoco/half_cheetah_v4.py +++ b/gym/envs/mujoco/half_cheetah_v4.py @@ -151,7 +151,14 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle): exclude_current_positions_from_observation=True, **kwargs ): - utils.EzPickle.__init__(**locals()) + utils.EzPickle.__init__( + self, + forward_reward_weight, + ctrl_cost_weight, + reset_noise_scale, + exclude_current_positions_from_observation, + **kwargs + ) self._forward_reward_weight = forward_reward_weight diff --git a/gym/envs/mujoco/hopper.py b/gym/envs/mujoco/hopper.py index e0b9fa59c..30c45c819 100644 --- a/gym/envs/mujoco/hopper.py +++ b/gym/envs/mujoco/hopper.py @@ -22,7 +22,7 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle): MuJocoPyEnv.__init__( self, "hopper.xml", 4, observation_space=observation_space, **kwargs ) - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) def step(self, a): posbefore = self.sim.data.qpos[0] diff --git a/gym/envs/mujoco/hopper_v3.py b/gym/envs/mujoco/hopper_v3.py index c3db1a9d3..522a8d8f3 100644 --- a/gym/envs/mujoco/hopper_v3.py +++ b/gym/envs/mujoco/hopper_v3.py @@ -40,7 +40,20 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle): exclude_current_positions_from_observation=True, **kwargs ): - utils.EzPickle.__init__(**locals()) + utils.EzPickle.__init__( + self, + xml_file, + forward_reward_weight, + ctrl_cost_weight, + healthy_reward, + terminate_when_unhealthy, + healthy_state_range, + healthy_z_range, + healthy_angle_range, + reset_noise_scale, + exclude_current_positions_from_observation, + **kwargs + ) self._forward_reward_weight = forward_reward_weight diff --git a/gym/envs/mujoco/hopper_v4.py b/gym/envs/mujoco/hopper_v4.py index 3d2acb6e1..34571accc 100644 --- a/gym/envs/mujoco/hopper_v4.py +++ b/gym/envs/mujoco/hopper_v4.py @@ -162,7 +162,19 @@ class HopperEnv(MujocoEnv, utils.EzPickle): exclude_current_positions_from_observation=True, **kwargs ): - utils.EzPickle.__init__(**locals()) + utils.EzPickle.__init__( + self, + forward_reward_weight, + ctrl_cost_weight, + healthy_reward, + terminate_when_unhealthy, + healthy_state_range, + healthy_z_range, + healthy_angle_range, + reset_noise_scale, + exclude_current_positions_from_observation, + **kwargs + ) self._forward_reward_weight = forward_reward_weight diff --git a/gym/envs/mujoco/humanoid.py b/gym/envs/mujoco/humanoid.py index b67179ee8..dfc529711 100644 --- a/gym/envs/mujoco/humanoid.py +++ b/gym/envs/mujoco/humanoid.py @@ -30,7 +30,7 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle): MuJocoPyEnv.__init__( self, "humanoid.xml", 5, observation_space=observation_space, **kwargs ) - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) def _get_obs(self): data = self.sim.data diff --git a/gym/envs/mujoco/humanoid_v3.py b/gym/envs/mujoco/humanoid_v3.py index 81c4b3d6f..d58c6542f 100644 --- a/gym/envs/mujoco/humanoid_v3.py +++ b/gym/envs/mujoco/humanoid_v3.py @@ -44,7 +44,20 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle): exclude_current_positions_from_observation=True, **kwargs ): - utils.EzPickle.__init__(**locals()) + utils.EzPickle.__init__( + self, + xml_file, + forward_reward_weight, + ctrl_cost_weight, + contact_cost_weight, + contact_cost_range, + healthy_reward, + terminate_when_unhealthy, + healthy_z_range, + reset_noise_scale, + exclude_current_positions_from_observation, + **kwargs + ) self._forward_reward_weight = forward_reward_weight self._ctrl_cost_weight = ctrl_cost_weight diff --git a/gym/envs/mujoco/humanoid_v4.py b/gym/envs/mujoco/humanoid_v4.py index 46620d774..95987b8e1 100644 --- a/gym/envs/mujoco/humanoid_v4.py +++ b/gym/envs/mujoco/humanoid_v4.py @@ -234,7 +234,17 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle): exclude_current_positions_from_observation=True, **kwargs ): - utils.EzPickle.__init__(**locals()) + utils.EzPickle.__init__( + self, + forward_reward_weight, + ctrl_cost_weight, + healthy_reward, + terminate_when_unhealthy, + healthy_z_range, + reset_noise_scale, + exclude_current_positions_from_observation, + **kwargs + ) self._forward_reward_weight = forward_reward_weight self._ctrl_cost_weight = ctrl_cost_weight diff --git a/gym/envs/mujoco/humanoidstandup.py b/gym/envs/mujoco/humanoidstandup.py index 06ef153e9..48e31c297 100644 --- a/gym/envs/mujoco/humanoidstandup.py +++ b/gym/envs/mujoco/humanoidstandup.py @@ -28,7 +28,7 @@ class HumanoidStandupEnv(MuJocoPyEnv, utils.EzPickle): observation_space=observation_space, **kwargs ) - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) def _get_obs(self): data = self.sim.data diff --git a/gym/envs/mujoco/humanoidstandup_v4.py b/gym/envs/mujoco/humanoidstandup_v4.py index b333f3fc9..4376421d8 100644 --- a/gym/envs/mujoco/humanoidstandup_v4.py +++ b/gym/envs/mujoco/humanoidstandup_v4.py @@ -200,7 +200,7 @@ class HumanoidStandupEnv(MujocoEnv, utils.EzPickle): observation_space=observation_space, **kwargs ) - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) def _get_obs(self): data = self.data diff --git a/gym/envs/mujoco/inverted_double_pendulum.py b/gym/envs/mujoco/inverted_double_pendulum.py index 3f41fc077..99f3931d2 100644 --- a/gym/envs/mujoco/inverted_double_pendulum.py +++ b/gym/envs/mujoco/inverted_double_pendulum.py @@ -26,7 +26,7 @@ class InvertedDoublePendulumEnv(MuJocoPyEnv, utils.EzPickle): observation_space=observation_space, **kwargs ) - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) def step(self, action): self.do_simulation(action, self.frame_skip) diff --git a/gym/envs/mujoco/inverted_double_pendulum_v4.py b/gym/envs/mujoco/inverted_double_pendulum_v4.py index c9d472d15..ece335794 100644 --- a/gym/envs/mujoco/inverted_double_pendulum_v4.py +++ b/gym/envs/mujoco/inverted_double_pendulum_v4.py @@ -132,7 +132,7 @@ class InvertedDoublePendulumEnv(MujocoEnv, utils.EzPickle): observation_space=observation_space, **kwargs ) - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) def step(self, action): self.do_simulation(action, self.frame_skip) diff --git a/gym/envs/mujoco/inverted_pendulum.py b/gym/envs/mujoco/inverted_pendulum.py index e41dbe45a..6c3592505 100644 --- a/gym/envs/mujoco/inverted_pendulum.py +++ b/gym/envs/mujoco/inverted_pendulum.py @@ -18,7 +18,7 @@ class InvertedPendulumEnv(MuJocoPyEnv, utils.EzPickle): } def __init__(self, **kwargs): - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) observation_space = Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64) MuJocoPyEnv.__init__( self, diff --git a/gym/envs/mujoco/inverted_pendulum_v4.py b/gym/envs/mujoco/inverted_pendulum_v4.py index cc029672f..7cb27fc45 100644 --- a/gym/envs/mujoco/inverted_pendulum_v4.py +++ b/gym/envs/mujoco/inverted_pendulum_v4.py @@ -95,7 +95,7 @@ class InvertedPendulumEnv(MujocoEnv, utils.EzPickle): } def __init__(self, **kwargs): - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) observation_space = Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64) MujocoEnv.__init__( self, diff --git a/gym/envs/mujoco/pusher.py b/gym/envs/mujoco/pusher.py index b4c7fe1c1..57d2f70d7 100644 --- a/gym/envs/mujoco/pusher.py +++ b/gym/envs/mujoco/pusher.py @@ -18,7 +18,7 @@ class PusherEnv(MuJocoPyEnv, utils.EzPickle): } def __init__(self, **kwargs): - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) observation_space = Box(low=-np.inf, high=np.inf, shape=(23,), dtype=np.float64) MuJocoPyEnv.__init__( self, "pusher.xml", 5, observation_space=observation_space, **kwargs diff --git a/gym/envs/mujoco/pusher_v4.py b/gym/envs/mujoco/pusher_v4.py index 306e272d2..c89a750d5 100644 --- a/gym/envs/mujoco/pusher_v4.py +++ b/gym/envs/mujoco/pusher_v4.py @@ -140,7 +140,7 @@ class PusherEnv(MujocoEnv, utils.EzPickle): } def __init__(self, **kwargs): - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) observation_space = Box(low=-np.inf, high=np.inf, shape=(23,), dtype=np.float64) MujocoEnv.__init__( self, "pusher.xml", 5, observation_space=observation_space, **kwargs diff --git a/gym/envs/mujoco/reacher.py b/gym/envs/mujoco/reacher.py index 73438666b..c852934ed 100644 --- a/gym/envs/mujoco/reacher.py +++ b/gym/envs/mujoco/reacher.py @@ -18,7 +18,7 @@ class ReacherEnv(MuJocoPyEnv, utils.EzPickle): } def __init__(self, **kwargs): - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) observation_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64) MuJocoPyEnv.__init__( self, "reacher.xml", 2, observation_space=observation_space, **kwargs diff --git a/gym/envs/mujoco/reacher_v4.py b/gym/envs/mujoco/reacher_v4.py index f0c334c0e..e3ede8d94 100644 --- a/gym/envs/mujoco/reacher_v4.py +++ b/gym/envs/mujoco/reacher_v4.py @@ -130,7 +130,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle): } def __init__(self, **kwargs): - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) observation_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64) MujocoEnv.__init__( self, "reacher.xml", 2, observation_space=observation_space, **kwargs diff --git a/gym/envs/mujoco/swimmer.py b/gym/envs/mujoco/swimmer.py index 137a97eb5..79a7dcf1a 100644 --- a/gym/envs/mujoco/swimmer.py +++ b/gym/envs/mujoco/swimmer.py @@ -22,7 +22,7 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle): MuJocoPyEnv.__init__( self, "swimmer.xml", 4, observation_space=observation_space, **kwargs ) - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) def step(self, a): ctrl_cost_coeff = 0.0001 diff --git a/gym/envs/mujoco/swimmer_v3.py b/gym/envs/mujoco/swimmer_v3.py index d17c09963..955c52acc 100644 --- a/gym/envs/mujoco/swimmer_v3.py +++ b/gym/envs/mujoco/swimmer_v3.py @@ -30,7 +30,15 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle): exclude_current_positions_from_observation=True, **kwargs ): - utils.EzPickle.__init__(**locals()) + utils.EzPickle.__init__( + self, + xml_file, + forward_reward_weight, + ctrl_cost_weight, + reset_noise_scale, + exclude_current_positions_from_observation, + **kwargs + ) self._forward_reward_weight = forward_reward_weight self._ctrl_cost_weight = ctrl_cost_weight diff --git a/gym/envs/mujoco/swimmer_v4.py b/gym/envs/mujoco/swimmer_v4.py index 0be0dc36a..363cd0f35 100644 --- a/gym/envs/mujoco/swimmer_v4.py +++ b/gym/envs/mujoco/swimmer_v4.py @@ -143,7 +143,14 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle): exclude_current_positions_from_observation=True, **kwargs ): - utils.EzPickle.__init__(**locals()) + utils.EzPickle.__init__( + self, + forward_reward_weight, + ctrl_cost_weight, + reset_noise_scale, + exclude_current_positions_from_observation, + **kwargs + ) self._forward_reward_weight = forward_reward_weight self._ctrl_cost_weight = ctrl_cost_weight diff --git a/gym/envs/mujoco/walker2d.py b/gym/envs/mujoco/walker2d.py index 12ec9630f..f3e74f2a3 100644 --- a/gym/envs/mujoco/walker2d.py +++ b/gym/envs/mujoco/walker2d.py @@ -22,7 +22,7 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle): MuJocoPyEnv.__init__( self, "walker2d.xml", 4, observation_space=observation_space, **kwargs ) - utils.EzPickle.__init__(self) + utils.EzPickle.__init__(self, **kwargs) def step(self, a): posbefore = self.sim.data.qpos[0] diff --git a/gym/envs/mujoco/walker2d_v3.py b/gym/envs/mujoco/walker2d_v3.py index 8688804fe..afd99f867 100644 --- a/gym/envs/mujoco/walker2d_v3.py +++ b/gym/envs/mujoco/walker2d_v3.py @@ -37,7 +37,19 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle): exclude_current_positions_from_observation=True, **kwargs ): - utils.EzPickle.__init__(**locals()) + utils.EzPickle.__init__( + self, + xml_file, + forward_reward_weight, + ctrl_cost_weight, + healthy_reward, + terminate_when_unhealthy, + healthy_z_range, + healthy_angle_range, + reset_noise_scale, + exclude_current_positions_from_observation, + **kwargs + ) self._forward_reward_weight = forward_reward_weight self._ctrl_cost_weight = ctrl_cost_weight diff --git a/gym/envs/mujoco/walker2d_v4.py b/gym/envs/mujoco/walker2d_v4.py index 795480fb7..a01f25dbb 100644 --- a/gym/envs/mujoco/walker2d_v4.py +++ b/gym/envs/mujoco/walker2d_v4.py @@ -166,7 +166,18 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle): exclude_current_positions_from_observation=True, **kwargs ): - utils.EzPickle.__init__(**locals()) + utils.EzPickle.__init__( + self, + forward_reward_weight, + ctrl_cost_weight, + healthy_reward, + terminate_when_unhealthy, + healthy_z_range, + healthy_angle_range, + reset_noise_scale, + exclude_current_positions_from_observation, + **kwargs + ) self._forward_reward_weight = forward_reward_weight self._ctrl_cost_weight = ctrl_cost_weight diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index db9312a21..89017fdb0 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -1,9 +1,16 @@ +import pickle + import pytest import gym from gym.envs.registration import EnvSpec -from gym.utils.env_checker import check_env -from tests.envs.utils import all_testing_env_specs, assert_equals, gym_testing_env_specs +from gym.utils.env_checker import check_env, data_equivalence +from tests.envs.utils import ( + all_testing_env_specs, + all_testing_initialised_envs, + assert_equals, + gym_testing_env_specs, +) # This runs a smoketest on each official registered env. We may want # to try also running environments which are not officially registered envs. @@ -120,3 +127,19 @@ def test_render_modes(spec): new_env.render() new_env.close() env.close() + + +@pytest.mark.parametrize( + "env", + all_testing_initialised_envs, + ids=[env.spec.id for env in all_testing_initialised_envs], +) +def test_pickle_env(env: gym.Env): + pickled_env = pickle.loads(pickle.dumps(env)) + + data_equivalence(env.reset(), pickled_env.reset()) + + action = env.action_space.sample() + data_equivalence(env.step(action), pickled_env.step(action)) + env.close() + pickled_env.close()