From 1ea5ec647c724eeed4d2a05f0092cfb2fd0ff81f Mon Sep 17 00:00:00 2001
From: Christopher Hesse
Date: Fri, 24 Aug 2018 15:44:56 -0700
Subject: [PATCH] export SimpleEnv and assert_envs_equal, fix minor bug in action space (#46)

---
 baselines/common/vec_env/test_vec_env.py | 62 +++++++++++++++---------
 1 file changed, 39 insertions(+), 23 deletions(-)

diff --git a/baselines/common/vec_env/test_vec_env.py b/baselines/common/vec_env/test_vec_env.py
index 6d0d41c..da2f3fb 100644
--- a/baselines/common/vec_env/test_vec_env.py
+++ b/baselines/common/vec_env/test_vec_env.py
@@ -10,6 +10,39 @@ from .shmem_vec_env import ShmemVecEnv
 from .subproc_vec_env import SubprocVecEnv
 
 
+def assert_envs_equal(env1, env2, num_steps):
+    """
+    Compare two environments over num_steps steps and make sure
+    that the observations produced by each are the same when given
+    the same actions.
+    """
+    assert env1.num_envs == env2.num_envs
+    assert env1.action_space.shape == env2.action_space.shape
+    assert env1.action_space.dtype == env2.action_space.dtype
+    joint_shape = (env1.num_envs,) + env1.action_space.shape
+
+    try:
+        obs1, obs2 = env1.reset(), env2.reset()
+        assert np.array(obs1).shape == np.array(obs2).shape
+        assert np.array(obs1).shape == joint_shape
+        assert np.allclose(obs1, obs2)
+        np.random.seed(1337)
+        for _ in range(num_steps):
+            actions = np.array(np.random.randint(0, 0x100, size=joint_shape),
+                               dtype=env1.action_space.dtype)
+            for env in [env1, env2]:
+                env.step_async(actions)
+            outs1 = env1.step_wait()
+            outs2 = env2.step_wait()
+            for out1, out2 in zip(outs1[:3], outs2[:3]):
+                assert np.array(out1).shape == np.array(out2).shape
+                assert np.allclose(out1, out2)
+            assert list(outs1[3]) == list(outs2[3])
+    finally:
+        env1.close()
+        env2.close()
+
+
 @pytest.mark.parametrize('klass', (ShmemVecEnv, SubprocVecEnv))
 @pytest.mark.parametrize('dtype', ('uint8', 'float32'))
 def test_vec_env(klass, dtype):  # pylint: disable=R0914
@@ -26,33 +59,14 @@ def test_vec_env(klass, dtype):  # pylint: disable=R0914
         """
         Get an environment constructor with a seed.
         """
-        return lambda: _SimpleEnv(seed, shape, dtype)
+        return lambda: SimpleEnv(seed, shape, dtype)
     fns = [make_fn(i) for i in range(num_envs)]
     env1 = DummyVecEnv(fns)
     env2 = klass(fns)
-    try:
-        obs1, obs2 = env1.reset(), env2.reset()
-        assert np.array(obs1).shape == np.array(obs2).shape
-        assert np.allclose(obs1, obs2)
-        np.random.seed(1337)
-        for _ in range(num_steps):
-            joint_shape = (len(fns),) + shape
-            actions = np.array(np.random.randint(0, 0x100, size=joint_shape),
-                               dtype=dtype)
-            for env in [env1, env2]:
-                env.step_async(actions)
-            outs1 = env1.step_wait()
-            outs2 = env2.step_wait()
-            for out1, out2 in zip(outs1[:3], outs2[:3]):
-                assert np.array(out1).shape == np.array(out2).shape
-                assert np.allclose(out1, out2)
-            assert list(outs1[3]) == list(outs2[3])
-    finally:
-        env1.close()
-        env2.close()
+    assert_envs_equal(env1, env2, num_steps=num_steps)
 
 
-class _SimpleEnv(gym.Env):
+class SimpleEnv(gym.Env):
     """
     An environment with a pre-determined observation space
     and RNG seed.
@@ -66,7 +80,9 @@ class _SimpleEnv(gym.Env):
         self._max_steps = seed + 1
         self._cur_obs = None
         self._cur_step = 0
-        self.action_space = gym.spaces.Box(low=0, high=100, shape=shape, dtype=dtype)
+        # this is 0xFF instead of 0x100 because the Box space includes
+        # the high end, while randint does not
+        self.action_space = gym.spaces.Box(low=0, high=0xFF, shape=shape, dtype=dtype)
         self.observation_space = self.action_space
 
     def step(self, action):