mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-31 10:09:53 +00:00
Remove pytest class in test pixel observation wrapper (#2902)
This commit is contained in:
@@ -48,76 +48,76 @@ class FakeDictObservationEnvironment(FakeEnvironment):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class TestPixelObservationWrapper(gym.Wrapper):
|
@pytest.mark.parametrize("pixels_only", (True, False))
|
||||||
@pytest.mark.parametrize("pixels_only", (True, False))
|
def test_dict_observation(pixels_only):
|
||||||
def test_dict_observation(self, pixels_only):
|
pixel_key = "rgb"
|
||||||
pixel_key = "rgb"
|
|
||||||
|
|
||||||
env = FakeDictObservationEnvironment()
|
env = FakeDictObservationEnvironment()
|
||||||
|
|
||||||
# Make sure we are testing the right environment for the test.
|
# Make sure we are testing the right environment for the test.
|
||||||
observation_space = env.observation_space
|
observation_space = env.observation_space
|
||||||
assert isinstance(observation_space, spaces.Dict)
|
assert isinstance(observation_space, spaces.Dict)
|
||||||
|
|
||||||
width, height = (320, 240)
|
width, height = (320, 240)
|
||||||
|
|
||||||
# The wrapper should only add one observation.
|
# The wrapper should only add one observation.
|
||||||
wrapped_env = PixelObservationWrapper(
|
wrapped_env = PixelObservationWrapper(
|
||||||
env,
|
env,
|
||||||
pixel_keys=(pixel_key,),
|
pixel_keys=(pixel_key,),
|
||||||
pixels_only=pixels_only,
|
pixels_only=pixels_only,
|
||||||
render_kwargs={pixel_key: {"width": width, "height": height}},
|
render_kwargs={pixel_key: {"width": width, "height": height}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(wrapped_env.observation_space, spaces.Dict)
|
||||||
|
|
||||||
|
if pixels_only:
|
||||||
|
assert len(wrapped_env.observation_space.spaces) == 1
|
||||||
|
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
len(wrapped_env.observation_space.spaces)
|
||||||
|
== len(observation_space.spaces) + 1
|
||||||
)
|
)
|
||||||
|
expected_keys = list(observation_space.spaces.keys()) + [pixel_key]
|
||||||
|
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
|
||||||
|
|
||||||
assert isinstance(wrapped_env.observation_space, spaces.Dict)
|
# Check that the added space item is consistent with the added observation.
|
||||||
|
observation = wrapped_env.reset()
|
||||||
|
rgb_observation = observation[pixel_key]
|
||||||
|
|
||||||
if pixels_only:
|
assert rgb_observation.shape == (height, width, 3)
|
||||||
assert len(wrapped_env.observation_space.spaces) == 1
|
assert rgb_observation.dtype == np.uint8
|
||||||
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
len(wrapped_env.observation_space.spaces)
|
|
||||||
== len(observation_space.spaces) + 1
|
|
||||||
)
|
|
||||||
expected_keys = list(observation_space.spaces.keys()) + [pixel_key]
|
|
||||||
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
|
|
||||||
|
|
||||||
# Check that the added space item is consistent with the added observation.
|
|
||||||
observation = wrapped_env.reset()
|
|
||||||
rgb_observation = observation[pixel_key]
|
|
||||||
|
|
||||||
assert rgb_observation.shape == (height, width, 3)
|
@pytest.mark.parametrize("pixels_only", (True, False))
|
||||||
assert rgb_observation.dtype == np.uint8
|
def test_single_array_observation(pixels_only):
|
||||||
|
pixel_key = "depth"
|
||||||
|
|
||||||
@pytest.mark.parametrize("pixels_only", (True, False))
|
env = FakeArrayObservationEnvironment()
|
||||||
def test_single_array_observation(self, pixels_only):
|
observation_space = env.observation_space
|
||||||
pixel_key = "depth"
|
assert isinstance(observation_space, spaces.Box)
|
||||||
|
|
||||||
env = FakeArrayObservationEnvironment()
|
wrapped_env = PixelObservationWrapper(
|
||||||
observation_space = env.observation_space
|
env, pixel_keys=(pixel_key,), pixels_only=pixels_only
|
||||||
assert isinstance(observation_space, spaces.Box)
|
)
|
||||||
|
wrapped_env.observation_space = wrapped_env.observation_space
|
||||||
|
assert isinstance(wrapped_env.observation_space, spaces.Dict)
|
||||||
|
|
||||||
wrapped_env = PixelObservationWrapper(
|
if pixels_only:
|
||||||
env, pixel_keys=(pixel_key,), pixels_only=pixels_only
|
assert len(wrapped_env.observation_space.spaces) == 1
|
||||||
)
|
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
|
||||||
wrapped_env.observation_space = wrapped_env.observation_space
|
else:
|
||||||
assert isinstance(wrapped_env.observation_space, spaces.Dict)
|
assert len(wrapped_env.observation_space.spaces) == 2
|
||||||
|
assert list(wrapped_env.observation_space.spaces.keys()) == [
|
||||||
|
STATE_KEY,
|
||||||
|
pixel_key,
|
||||||
|
]
|
||||||
|
|
||||||
if pixels_only:
|
observation = wrapped_env.reset()
|
||||||
assert len(wrapped_env.observation_space.spaces) == 1
|
depth_observation = observation[pixel_key]
|
||||||
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
|
|
||||||
else:
|
|
||||||
assert len(wrapped_env.observation_space.spaces) == 2
|
|
||||||
assert list(wrapped_env.observation_space.spaces.keys()) == [
|
|
||||||
STATE_KEY,
|
|
||||||
pixel_key,
|
|
||||||
]
|
|
||||||
|
|
||||||
observation = wrapped_env.reset()
|
assert depth_observation.shape == (32, 32, 3)
|
||||||
depth_observation = observation[pixel_key]
|
assert depth_observation.dtype == np.uint8
|
||||||
|
|
||||||
assert depth_observation.shape == (32, 32, 3)
|
if not pixels_only:
|
||||||
assert depth_observation.dtype == np.uint8
|
assert isinstance(observation[STATE_KEY], np.ndarray)
|
||||||
|
|
||||||
if not pixels_only:
|
|
||||||
assert isinstance(observation[STATE_KEY], np.ndarray)
|
|
||||||
|
Reference in New Issue
Block a user