2022-05-12 15:33:48 +02:00
|
|
|
"""Test environment determinism by performing a rollout."""
|
|
|
|
|
2016-05-29 09:07:09 -07:00
|
|
|
import numpy as np
|
2017-02-11 22:17:02 -08:00
|
|
|
import pytest
|
2019-01-30 22:39:55 +01:00
|
|
|
|
2021-09-29 01:53:30 +02:00
|
|
|
from tests.envs.spec_list import spec_list
|
2016-05-29 09:07:09 -07:00
|
|
|
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2022-04-10 18:36:23 +01:00
|
|
|
@pytest.mark.parametrize("spec", spec_list, ids=[spec.id for spec in spec_list])
def test_env(spec):
    """Run a rollout with two environments and assert equality.

    This test runs a rollout of NUM_STEPS steps with two environments
    initialized with the same seed and asserts that:

    - the observations after the first reset are the same
    - the same actions are sampled by the two envs
    - observations are contained in the observation space
    - obs, rew, done and info are equal between the two envs

    For specs marked ``nondeterministic`` only the initial observations and
    the first pair of sampled actions are compared; the rollout itself is
    skipped. Both environments are closed on every exit path, including
    failed assertions and that early return.

    Args:
        spec (EnvSpec): Environment specification
    """
    from contextlib import closing

    # Note that this precludes running this test in multiple
    # threads. However, we probably already can't do multithreading
    # due to some environments.
    # NOTE(review): seeding is done per-env via reset(seed=...) and
    # action_space.seed(...); confirm whether this note is still accurate.
    SEED = 0
    NUM_STEPS = 50

    # `closing` guarantees env.close() runs when the block is left for any
    # reason: a failed assertion, the nondeterministic early return, or the
    # second make() raising after the first env was already created.
    with closing(spec.make()) as env1, closing(spec.make()) as env2:
        initial_observation1 = env1.reset(seed=SEED)
        initial_observation2 = env2.reset(seed=SEED)

        env1.action_space.seed(SEED)
        env2.action_space.seed(SEED)

        assert_equals(initial_observation1, initial_observation2)

        for i in range(NUM_STEPS):
            action1 = env1.action_space.sample()
            action2 = env2.action_space.sample()

            try:
                assert_equals(action1, action2, f"[{i}] ")
            except AssertionError:
                # Dump both spaces to help diagnose seeding mismatches.
                print("env1.action_space=", env1.action_space)
                print("env2.action_space=", env2.action_space)
                print(f"[{i}] action_sample1: {action1}, action_sample2: {action2}")
                raise

            # Don't check rollout equality if it's a nondeterministic
            # environment.
            if spec.nondeterministic:
                return

            obs1, rew1, done1, info1 = env1.step(action1)
            obs2, rew2, done2, info2 = env2.step(action2)

            assert_equals(obs1, obs2, f"[{i}] ")

            assert env1.observation_space.contains(obs1)
            assert env2.observation_space.contains(obs2)

            assert rew1 == rew2, f"[{i}] r1: {rew1}, r2: {rew2}"
            assert done1 == done2, f"[{i}] d1: {done1}, d2: {done2}"
            assert_equals(info1, info2, f"[{i}] ")

            if done1:  # done2 verified in previous assertion
                env1.reset(seed=SEED)
                env2.reset(seed=SEED)
|
2021-07-29 02:26:34 +02:00
|
|
|
|
2016-05-29 09:07:09 -07:00
|
|
|
|
|
|
|
def assert_equals(a, b, prefix=None):
    """Assert equality of data structures `a` and `b`.

    Types must match exactly at every level. Dicts (same keys in the same
    order), tuples and lists are compared recursively; numpy arrays are
    compared element-wise with ``np.testing.assert_array_equal``; anything
    else is compared with ``==``.

    Args:
        a: first data structure
        b: second data structure
        prefix: text prepended to every failed assertion message (e.g. the
            step index). ``None`` means no prefix. When recursing, the
            dict key / sequence index is appended so nested failures show
            their path.

    Raises:
        AssertionError: if `a` and `b` differ in type, structure or value.
    """
    # Avoid messages starting with the literal text "None".
    prefix = "" if prefix is None else prefix
    assert type(a) == type(b), f"{prefix}Differing types: {a} and {b}"
    if isinstance(a, dict):
        assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"
        for k in a:
            assert_equals(a[k], b[k], f"{prefix}[{k!r}] ")
    elif isinstance(a, np.ndarray):
        np.testing.assert_array_equal(a, b, err_msg=prefix)
    elif isinstance(a, (tuple, list)):
        # zip() would silently ignore trailing elements of the longer one.
        assert len(a) == len(b), f"{prefix}Differing lengths: {a} and {b}"
        for idx, (elem_from_a, elem_from_b) in enumerate(zip(a, b)):
            assert_equals(elem_from_a, elem_from_b, f"{prefix}[{idx}] ")
    else:
        assert a == b, f"{prefix}{a} != {b}"
|