102 lines
3.3 KiB
Python
102 lines
3.3 KiB
Python
"""
|
|
Tests for asynchronous vectorized environments.
|
|
"""
|
|
|
|
import gym
|
|
import numpy as np
|
|
import pytest
|
|
from .dummy_vec_env import DummyVecEnv
|
|
from .shmem_vec_env import ShmemVecEnv
|
|
from .subproc_vec_env import SubprocVecEnv
|
|
|
|
|
|
def assert_envs_equal(env1, env2, num_steps):
|
|
"""
|
|
Compare two environments over num_steps steps and make sure
|
|
that the observations produced by each are the same when given
|
|
the same actions.
|
|
"""
|
|
assert env1.num_envs == env2.num_envs
|
|
assert env1.action_space.shape == env2.action_space.shape
|
|
assert env1.action_space.dtype == env2.action_space.dtype
|
|
joint_shape = (env1.num_envs,) + env1.action_space.shape
|
|
|
|
try:
|
|
obs1, obs2 = env1.reset(), env2.reset()
|
|
assert np.array(obs1).shape == np.array(obs2).shape
|
|
assert np.array(obs1).shape == joint_shape
|
|
assert np.allclose(obs1, obs2)
|
|
np.random.seed(1337)
|
|
for _ in range(num_steps):
|
|
actions = np.array(np.random.randint(0, 0x100, size=joint_shape),
|
|
dtype=env1.action_space.dtype)
|
|
for env in [env1, env2]:
|
|
env.step_async(actions)
|
|
outs1 = env1.step_wait()
|
|
outs2 = env2.step_wait()
|
|
for out1, out2 in zip(outs1[:3], outs2[:3]):
|
|
assert np.array(out1).shape == np.array(out2).shape
|
|
assert np.allclose(out1, out2)
|
|
assert list(outs1[3]) == list(outs2[3])
|
|
finally:
|
|
env1.close()
|
|
env2.close()
|
|
|
|
|
|
@pytest.mark.parametrize('klass', (ShmemVecEnv, SubprocVecEnv))
|
|
@pytest.mark.parametrize('dtype', ('uint8', 'float32'))
|
|
def test_vec_env(klass, dtype): # pylint: disable=R0914
|
|
"""
|
|
Test that a vectorized environment is equivalent to
|
|
DummyVecEnv, since DummyVecEnv is less likely to be
|
|
error prone.
|
|
"""
|
|
num_envs = 3
|
|
num_steps = 100
|
|
shape = (3, 8)
|
|
|
|
def make_fn(seed):
|
|
"""
|
|
Get an environment constructor with a seed.
|
|
"""
|
|
return lambda: SimpleEnv(seed, shape, dtype)
|
|
fns = [make_fn(i) for i in range(num_envs)]
|
|
env1 = DummyVecEnv(fns)
|
|
env2 = klass(fns)
|
|
assert_envs_equal(env1, env2, num_steps=num_steps)
|
|
|
|
|
|
class SimpleEnv(gym.Env):
|
|
"""
|
|
An environment with a pre-determined observation space
|
|
and RNG seed.
|
|
"""
|
|
|
|
def __init__(self, seed, shape, dtype):
|
|
np.random.seed(seed)
|
|
self._dtype = dtype
|
|
self._start_obs = np.array(np.random.randint(0, 0x100, size=shape),
|
|
dtype=dtype)
|
|
self._max_steps = seed + 1
|
|
self._cur_obs = None
|
|
self._cur_step = 0
|
|
# this is 0xFF instead of 0x100 because the Box space includes
|
|
# the high end, while randint does not
|
|
self.action_space = gym.spaces.Box(low=0, high=0xFF, shape=shape, dtype=dtype)
|
|
self.observation_space = self.action_space
|
|
|
|
def step(self, action):
|
|
self._cur_obs += np.array(action, dtype=self._dtype)
|
|
self._cur_step += 1
|
|
done = self._cur_step >= self._max_steps
|
|
reward = self._cur_step / self._max_steps
|
|
return self._cur_obs, reward, done, {'foo': 'bar' + str(reward)}
|
|
|
|
def reset(self):
|
|
self._cur_obs = self._start_obs
|
|
self._cur_step = 0
|
|
return self._cur_obs
|
|
|
|
def render(self, mode=None):
|
|
raise NotImplementedError
|