2016-05-29 09:07:09 -07:00
|
|
|
import numpy as np
|
|
|
|
from nose2 import tools
|
|
|
|
import os
|
|
|
|
|
|
|
|
import logging
|
|
|
|
# Module-level logger; not referenced in this chunk, but kept available for
# ad-hoc debugging of environment tests.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
import gym
|
2016-05-30 18:07:59 -07:00
|
|
|
from gym import envs, spaces
|
2016-05-29 09:07:09 -07:00
|
|
|
|
2016-05-31 00:57:31 -07:00
|
|
|
from test_envs import should_skip_env_spec_for_tests
|
|
|
|
|
2016-05-29 09:07:09 -07:00
|
|
|
# Every registered env spec that has a concrete entry point (i.e. can actually
# be instantiated); parametrizes test_env below via @tools.params.
specs = [spec for spec in envs.registry.all() if spec._entry_point is not None]
|
|
|
|
@tools.params(*specs)
def test_env(spec):
    """Check that an environment spec is deterministic.

    Two environments built from the same spec and seeded identically must
    produce identical action/observation samples, initial observations and
    step responses — unless the spec is flagged nondeterministic, in which
    case only the sampled actions/observations are compared.

    Args:
        spec: a gym EnvSpec from the registry.

    Raises:
        AssertionError: if the two runs diverge.
    """
    if should_skip_env_spec_for_tests(spec):
        return

    # Note that this precludes running this test in multiple
    # threads. However, we probably already can't do multithreading
    # due to some environments.
    spaces.seed(0)

    env1 = spec.make()
    env1.seed(0)
    action_samples1 = [env1.action_space.sample() for i in range(4)]
    observation_samples1 = [env1.observation_space.sample() for i in range(4)]
    initial_observation1 = env1.reset()
    step_responses1 = [env1.step(action) for action in action_samples1]
    env1.close()

    # Re-seed the global space RNG so the second env draws the same samples.
    spaces.seed(0)

    env2 = spec.make()
    env2.seed(0)
    action_samples2 = [env2.action_space.sample() for i in range(4)]
    observation_samples2 = [env2.observation_space.sample() for i in range(4)]
    initial_observation2 = env2.reset()
    step_responses2 = [env2.step(action) for action in action_samples2]
    env2.close()

    for i, (action_sample1, action_sample2) in enumerate(zip(action_samples1, action_samples2)):
        # Bug fix: the context message used to be written as a separate tuple
        # element (`assert_equals(a, b), 'msg'`), so it was built and silently
        # discarded. Pass it as the prefix argument instead, matching the
        # step-response checks below.
        assert_equals(action_sample1, action_sample2, '[{}] '.format(i))

    for (observation_sample1, observation_sample2) in zip(observation_samples1, observation_samples2):
        assert_equals(observation_sample1, observation_sample2)

    # Don't check rollout equality if it's a nondeterministic
    # environment.
    if spec.nondeterministic:
        return

    assert_equals(initial_observation1, initial_observation2)

    for i, ((o1, r1, d1, i1), (o2, r2, d2, i2)) in enumerate(zip(step_responses1, step_responses2)):
        assert_equals(o1, o2, '[{}] '.format(i))
        assert r1 == r2, '[{}] r1: {}, r2: {}'.format(i, r1, r2)
        assert d1 == d2, '[{}] d1: {}, d2: {}'.format(i, d1, d2)

        # Go returns a Pachi game board in info, which doesn't
        # properly check equality. For now, we hack around this by
        # just skipping Go.
        if spec.id not in ['Go9x9-v0', 'Go19x19-v0']:
            assert_equals(i1, i2, '[{}] '.format(i))
|
|
|
|
|
|
|
|
def assert_equals(a, b, prefix=None):
    """Recursively assert that two values are equal.

    Handles dicts (key sets then values), numpy arrays (element-wise) and
    tuples (length then element-wise), falling back to `==` for everything
    else.

    Args:
        a, b: the values to compare. Must be of exactly the same type.
        prefix: optional string prepended to failure messages for context.

    Raises:
        AssertionError: if the values differ.
    """
    # Bug fix: a missing prefix used to be interpolated as the literal text
    # 'None' at the front of every failure message; normalize it to ''.
    if prefix is None:
        prefix = ''
    assert type(a) == type(b), "{}Differing types: {} and {}".format(prefix, a, b)
    if isinstance(a, dict):
        assert list(a.keys()) == list(b.keys()), "{}Key sets differ: {} and {}".format(prefix, a, b)

        for k in a.keys():
            v_a = a[k]
            v_b = b[k]
            # Bug fix: propagate the prefix so nested failures keep their
            # context (it was previously dropped on recursion).
            assert_equals(v_a, v_b, prefix)
    elif isinstance(a, np.ndarray):
        np.testing.assert_array_equal(a, b)
    elif isinstance(a, tuple):
        # Bug fix: zip truncates to the shorter tuple, so tuples of different
        # lengths used to compare equal on their common prefix.
        assert len(a) == len(b), "{}Differing lengths: {} and {}".format(prefix, a, b)
        for elem_from_a, elem_from_b in zip(a, b):
            assert_equals(elem_from_a, elem_from_b, prefix)
    else:
        assert a == b, "{}Differing values: {} and {}".format(prefix, a, b)
|