Remove pytest class in test pixel observation wrapper (#2902)

This commit is contained in:
Mark Towers
2022-06-19 22:05:56 +01:00
committed by GitHub
parent 979407f4c4
commit feea527a4f

View File

@@ -48,76 +48,76 @@ class FakeDictObservationEnvironment(FakeEnvironment):
super().__init__(*args, **kwargs)
class TestPixelObservationWrapper(gym.Wrapper):
@pytest.mark.parametrize("pixels_only", (True, False))
def test_dict_observation(self, pixels_only):
pixel_key = "rgb"
@pytest.mark.parametrize("pixels_only", (True, False))
def test_dict_observation(pixels_only):
pixel_key = "rgb"
env = FakeDictObservationEnvironment()
env = FakeDictObservationEnvironment()
# Make sure we are testing the right environment for the test.
observation_space = env.observation_space
assert isinstance(observation_space, spaces.Dict)
# Make sure we are testing the right environment for the test.
observation_space = env.observation_space
assert isinstance(observation_space, spaces.Dict)
width, height = (320, 240)
width, height = (320, 240)
# The wrapper should only add one observation.
wrapped_env = PixelObservationWrapper(
env,
pixel_keys=(pixel_key,),
pixels_only=pixels_only,
render_kwargs={pixel_key: {"width": width, "height": height}},
# The wrapper should only add one observation.
wrapped_env = PixelObservationWrapper(
env,
pixel_keys=(pixel_key,),
pixels_only=pixels_only,
render_kwargs={pixel_key: {"width": width, "height": height}},
)
assert isinstance(wrapped_env.observation_space, spaces.Dict)
if pixels_only:
assert len(wrapped_env.observation_space.spaces) == 1
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
else:
assert (
len(wrapped_env.observation_space.spaces)
== len(observation_space.spaces) + 1
)
expected_keys = list(observation_space.spaces.keys()) + [pixel_key]
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
assert isinstance(wrapped_env.observation_space, spaces.Dict)
# Check that the added space item is consistent with the added observation.
observation = wrapped_env.reset()
rgb_observation = observation[pixel_key]
if pixels_only:
assert len(wrapped_env.observation_space.spaces) == 1
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
else:
assert (
len(wrapped_env.observation_space.spaces)
== len(observation_space.spaces) + 1
)
expected_keys = list(observation_space.spaces.keys()) + [pixel_key]
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
assert rgb_observation.shape == (height, width, 3)
assert rgb_observation.dtype == np.uint8
# Check that the added space item is consistent with the added observation.
observation = wrapped_env.reset()
rgb_observation = observation[pixel_key]
assert rgb_observation.shape == (height, width, 3)
assert rgb_observation.dtype == np.uint8
@pytest.mark.parametrize("pixels_only", (True, False))
def test_single_array_observation(pixels_only):
pixel_key = "depth"
@pytest.mark.parametrize("pixels_only", (True, False))
def test_single_array_observation(self, pixels_only):
pixel_key = "depth"
env = FakeArrayObservationEnvironment()
observation_space = env.observation_space
assert isinstance(observation_space, spaces.Box)
env = FakeArrayObservationEnvironment()
observation_space = env.observation_space
assert isinstance(observation_space, spaces.Box)
wrapped_env = PixelObservationWrapper(
env, pixel_keys=(pixel_key,), pixels_only=pixels_only
)
wrapped_env.observation_space = wrapped_env.observation_space
assert isinstance(wrapped_env.observation_space, spaces.Dict)
wrapped_env = PixelObservationWrapper(
env, pixel_keys=(pixel_key,), pixels_only=pixels_only
)
wrapped_env.observation_space = wrapped_env.observation_space
assert isinstance(wrapped_env.observation_space, spaces.Dict)
if pixels_only:
assert len(wrapped_env.observation_space.spaces) == 1
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
else:
assert len(wrapped_env.observation_space.spaces) == 2
assert list(wrapped_env.observation_space.spaces.keys()) == [
STATE_KEY,
pixel_key,
]
if pixels_only:
assert len(wrapped_env.observation_space.spaces) == 1
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
else:
assert len(wrapped_env.observation_space.spaces) == 2
assert list(wrapped_env.observation_space.spaces.keys()) == [
STATE_KEY,
pixel_key,
]
observation = wrapped_env.reset()
depth_observation = observation[pixel_key]
observation = wrapped_env.reset()
depth_observation = observation[pixel_key]
assert depth_observation.shape == (32, 32, 3)
assert depth_observation.dtype == np.uint8
assert depth_observation.shape == (32, 32, 3)
assert depth_observation.dtype == np.uint8
if not pixels_only:
assert isinstance(observation[STATE_KEY], np.ndarray)
if not pixels_only:
assert isinstance(observation[STATE_KEY], np.ndarray)