Fix pickle env testing (#1034)

This commit is contained in:
Mark Towers
2024-04-18 23:55:36 +01:00
committed by GitHub
parent 6a8a267c5b
commit 1b2c1ff084

View File

@@ -149,11 +149,17 @@ def test_pickle_env(env: gym.Env):
if env.metadata.get("jax", False):
env = gym.wrappers.JaxToNumpy(env)
pickled_env = pickle.loads(pickle.dumps(env))
data_equivalence(env.reset(), pickled_env.reset())
action = env.action_space.sample()
data_equivalence(env.step(action), pickled_env.step(action))
env_reset = env.reset(seed=123)
env_step = env.step(action)
pickled_env = pickle.loads(pickle.dumps(env))
pickle_reset = pickled_env.reset(seed=123)
pickle_step = pickled_env.step(action)
assert data_equivalence(env_reset, pickle_reset)
assert data_equivalence(env_step, pickle_step)
env.close()
pickled_env.close()