Remove pytest class in test pixel observation wrapper (#2902)

This commit is contained in:
Mark Towers
2022-06-19 22:05:56 +01:00
committed by GitHub
parent 979407f4c4
commit feea527a4f

View File

@@ -48,76 +48,76 @@ class FakeDictObservationEnvironment(FakeEnvironment):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
class TestPixelObservationWrapper(gym.Wrapper): @pytest.mark.parametrize("pixels_only", (True, False))
@pytest.mark.parametrize("pixels_only", (True, False)) def test_dict_observation(pixels_only):
def test_dict_observation(self, pixels_only): pixel_key = "rgb"
pixel_key = "rgb"
env = FakeDictObservationEnvironment() env = FakeDictObservationEnvironment()
# Make sure we are testing the right environment for the test. # Make sure we are testing the right environment for the test.
observation_space = env.observation_space observation_space = env.observation_space
assert isinstance(observation_space, spaces.Dict) assert isinstance(observation_space, spaces.Dict)
width, height = (320, 240) width, height = (320, 240)
# The wrapper should only add one observation. # The wrapper should only add one observation.
wrapped_env = PixelObservationWrapper( wrapped_env = PixelObservationWrapper(
env, env,
pixel_keys=(pixel_key,), pixel_keys=(pixel_key,),
pixels_only=pixels_only, pixels_only=pixels_only,
render_kwargs={pixel_key: {"width": width, "height": height}}, render_kwargs={pixel_key: {"width": width, "height": height}},
)
assert isinstance(wrapped_env.observation_space, spaces.Dict)
if pixels_only:
assert len(wrapped_env.observation_space.spaces) == 1
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
else:
assert (
len(wrapped_env.observation_space.spaces)
== len(observation_space.spaces) + 1
) )
expected_keys = list(observation_space.spaces.keys()) + [pixel_key]
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
assert isinstance(wrapped_env.observation_space, spaces.Dict) # Check that the added space item is consistent with the added observation.
observation = wrapped_env.reset()
rgb_observation = observation[pixel_key]
if pixels_only: assert rgb_observation.shape == (height, width, 3)
assert len(wrapped_env.observation_space.spaces) == 1 assert rgb_observation.dtype == np.uint8
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
else:
assert (
len(wrapped_env.observation_space.spaces)
== len(observation_space.spaces) + 1
)
expected_keys = list(observation_space.spaces.keys()) + [pixel_key]
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
# Check that the added space item is consistent with the added observation.
observation = wrapped_env.reset()
rgb_observation = observation[pixel_key]
assert rgb_observation.shape == (height, width, 3) @pytest.mark.parametrize("pixels_only", (True, False))
assert rgb_observation.dtype == np.uint8 def test_single_array_observation(pixels_only):
pixel_key = "depth"
@pytest.mark.parametrize("pixels_only", (True, False)) env = FakeArrayObservationEnvironment()
def test_single_array_observation(self, pixels_only): observation_space = env.observation_space
pixel_key = "depth" assert isinstance(observation_space, spaces.Box)
env = FakeArrayObservationEnvironment() wrapped_env = PixelObservationWrapper(
observation_space = env.observation_space env, pixel_keys=(pixel_key,), pixels_only=pixels_only
assert isinstance(observation_space, spaces.Box) )
wrapped_env.observation_space = wrapped_env.observation_space
assert isinstance(wrapped_env.observation_space, spaces.Dict)
wrapped_env = PixelObservationWrapper( if pixels_only:
env, pixel_keys=(pixel_key,), pixels_only=pixels_only assert len(wrapped_env.observation_space.spaces) == 1
) assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
wrapped_env.observation_space = wrapped_env.observation_space else:
assert isinstance(wrapped_env.observation_space, spaces.Dict) assert len(wrapped_env.observation_space.spaces) == 2
assert list(wrapped_env.observation_space.spaces.keys()) == [
STATE_KEY,
pixel_key,
]
if pixels_only: observation = wrapped_env.reset()
assert len(wrapped_env.observation_space.spaces) == 1 depth_observation = observation[pixel_key]
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
else:
assert len(wrapped_env.observation_space.spaces) == 2
assert list(wrapped_env.observation_space.spaces.keys()) == [
STATE_KEY,
pixel_key,
]
observation = wrapped_env.reset() assert depth_observation.shape == (32, 32, 3)
depth_observation = observation[pixel_key] assert depth_observation.dtype == np.uint8
assert depth_observation.shape == (32, 32, 3) if not pixels_only:
assert depth_observation.dtype == np.uint8 assert isinstance(observation[STATE_KEY], np.ndarray)
if not pixels_only:
assert isinstance(observation[STATE_KEY], np.ndarray)