mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 01:50:19 +00:00
Remove pytest class in test pixel observation wrapper (#2902)
This commit is contained in:
@@ -48,76 +48,76 @@ class FakeDictObservationEnvironment(FakeEnvironment):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class TestPixelObservationWrapper(gym.Wrapper):
|
||||
@pytest.mark.parametrize("pixels_only", (True, False))
|
||||
def test_dict_observation(self, pixels_only):
|
||||
pixel_key = "rgb"
|
||||
@pytest.mark.parametrize("pixels_only", (True, False))
|
||||
def test_dict_observation(pixels_only):
|
||||
pixel_key = "rgb"
|
||||
|
||||
env = FakeDictObservationEnvironment()
|
||||
env = FakeDictObservationEnvironment()
|
||||
|
||||
# Make sure we are testing the right environment for the test.
|
||||
observation_space = env.observation_space
|
||||
assert isinstance(observation_space, spaces.Dict)
|
||||
# Make sure we are testing the right environment for the test.
|
||||
observation_space = env.observation_space
|
||||
assert isinstance(observation_space, spaces.Dict)
|
||||
|
||||
width, height = (320, 240)
|
||||
width, height = (320, 240)
|
||||
|
||||
# The wrapper should only add one observation.
|
||||
wrapped_env = PixelObservationWrapper(
|
||||
env,
|
||||
pixel_keys=(pixel_key,),
|
||||
pixels_only=pixels_only,
|
||||
render_kwargs={pixel_key: {"width": width, "height": height}},
|
||||
# The wrapper should only add one observation.
|
||||
wrapped_env = PixelObservationWrapper(
|
||||
env,
|
||||
pixel_keys=(pixel_key,),
|
||||
pixels_only=pixels_only,
|
||||
render_kwargs={pixel_key: {"width": width, "height": height}},
|
||||
)
|
||||
|
||||
assert isinstance(wrapped_env.observation_space, spaces.Dict)
|
||||
|
||||
if pixels_only:
|
||||
assert len(wrapped_env.observation_space.spaces) == 1
|
||||
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
|
||||
else:
|
||||
assert (
|
||||
len(wrapped_env.observation_space.spaces)
|
||||
== len(observation_space.spaces) + 1
|
||||
)
|
||||
expected_keys = list(observation_space.spaces.keys()) + [pixel_key]
|
||||
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
|
||||
|
||||
assert isinstance(wrapped_env.observation_space, spaces.Dict)
|
||||
# Check that the added space item is consistent with the added observation.
|
||||
observation = wrapped_env.reset()
|
||||
rgb_observation = observation[pixel_key]
|
||||
|
||||
if pixels_only:
|
||||
assert len(wrapped_env.observation_space.spaces) == 1
|
||||
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
|
||||
else:
|
||||
assert (
|
||||
len(wrapped_env.observation_space.spaces)
|
||||
== len(observation_space.spaces) + 1
|
||||
)
|
||||
expected_keys = list(observation_space.spaces.keys()) + [pixel_key]
|
||||
assert list(wrapped_env.observation_space.spaces.keys()) == expected_keys
|
||||
assert rgb_observation.shape == (height, width, 3)
|
||||
assert rgb_observation.dtype == np.uint8
|
||||
|
||||
# Check that the added space item is consistent with the added observation.
|
||||
observation = wrapped_env.reset()
|
||||
rgb_observation = observation[pixel_key]
|
||||
|
||||
assert rgb_observation.shape == (height, width, 3)
|
||||
assert rgb_observation.dtype == np.uint8
|
||||
@pytest.mark.parametrize("pixels_only", (True, False))
|
||||
def test_single_array_observation(pixels_only):
|
||||
pixel_key = "depth"
|
||||
|
||||
@pytest.mark.parametrize("pixels_only", (True, False))
|
||||
def test_single_array_observation(self, pixels_only):
|
||||
pixel_key = "depth"
|
||||
env = FakeArrayObservationEnvironment()
|
||||
observation_space = env.observation_space
|
||||
assert isinstance(observation_space, spaces.Box)
|
||||
|
||||
env = FakeArrayObservationEnvironment()
|
||||
observation_space = env.observation_space
|
||||
assert isinstance(observation_space, spaces.Box)
|
||||
wrapped_env = PixelObservationWrapper(
|
||||
env, pixel_keys=(pixel_key,), pixels_only=pixels_only
|
||||
)
|
||||
wrapped_env.observation_space = wrapped_env.observation_space
|
||||
assert isinstance(wrapped_env.observation_space, spaces.Dict)
|
||||
|
||||
wrapped_env = PixelObservationWrapper(
|
||||
env, pixel_keys=(pixel_key,), pixels_only=pixels_only
|
||||
)
|
||||
wrapped_env.observation_space = wrapped_env.observation_space
|
||||
assert isinstance(wrapped_env.observation_space, spaces.Dict)
|
||||
if pixels_only:
|
||||
assert len(wrapped_env.observation_space.spaces) == 1
|
||||
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
|
||||
else:
|
||||
assert len(wrapped_env.observation_space.spaces) == 2
|
||||
assert list(wrapped_env.observation_space.spaces.keys()) == [
|
||||
STATE_KEY,
|
||||
pixel_key,
|
||||
]
|
||||
|
||||
if pixels_only:
|
||||
assert len(wrapped_env.observation_space.spaces) == 1
|
||||
assert list(wrapped_env.observation_space.spaces.keys()) == [pixel_key]
|
||||
else:
|
||||
assert len(wrapped_env.observation_space.spaces) == 2
|
||||
assert list(wrapped_env.observation_space.spaces.keys()) == [
|
||||
STATE_KEY,
|
||||
pixel_key,
|
||||
]
|
||||
observation = wrapped_env.reset()
|
||||
depth_observation = observation[pixel_key]
|
||||
|
||||
observation = wrapped_env.reset()
|
||||
depth_observation = observation[pixel_key]
|
||||
assert depth_observation.shape == (32, 32, 3)
|
||||
assert depth_observation.dtype == np.uint8
|
||||
|
||||
assert depth_observation.shape == (32, 32, 3)
|
||||
assert depth_observation.dtype == np.uint8
|
||||
|
||||
if not pixels_only:
|
||||
assert isinstance(observation[STATE_KEY], np.ndarray)
|
||||
if not pixels_only:
|
||||
assert isinstance(observation[STATE_KEY], np.ndarray)
|
||||
|
Reference in New Issue
Block a user