diff --git a/tests/testing_env.py b/tests/testing_env.py index ecac8b737..bc03c35be 100644 --- a/tests/testing_env.py +++ b/tests/testing_env.py @@ -24,12 +24,18 @@ def basic_reset_fn( return self.observation_space.sample() -def basic_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: - """A basic step function that will pass the environment check using random actions from the observation space.""" +def new_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: + """A step function that follows the new step api that will pass the environment check using random actions from the observation space.""" return self.observation_space.sample(), 0, False, False, {} +def old_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, dict]: + """A step function that follows the old step api that will pass the environment check using random actions from the observation space.""" + return self.observation_space.sample(), 0, False, {} + + def basic_render_fn(self): + """Basic render fn that does nothing.""" pass @@ -42,7 +48,7 @@ class GenericTestEnv(gym.Env): action_space: spaces.Space = spaces.Box(0, 1, (1,)), observation_space: spaces.Space = spaces.Box(0, 1, (1,)), reset_fn: callable = basic_reset_fn, - step_fn: callable = basic_step_fn, + step_fn: callable = new_step_fn, render_fn: callable = basic_render_fn, render_modes: Optional[List[str]] = None, render_fps: Optional[int] = None, diff --git a/tests/wrappers/test_atari_preprocessing.py b/tests/wrappers/test_atari_preprocessing.py index 845d7b25c..e083d3f0e 100644 --- a/tests/wrappers/test_atari_preprocessing.py +++ b/tests/wrappers/test_atari_preprocessing.py @@ -1,90 +1,124 @@ import numpy as np import pytest -import gym +from gym.spaces import Box, Discrete from gym.wrappers import AtariPreprocessing - -pytest.importorskip("gym.envs.atari") +from tests.testing_env import GenericTestEnv, old_step_fn -@pytest.fixture(scope="module") -def env_fn(): - return lambda: gym.make("PongNoFrameskip-v4", disable_env_checker=True) +class AleTesting: + """A testing implementation for the ALE object in atari games.""" + + grayscale_obs_space = Box(low=0, high=255, shape=(210, 160), dtype=np.uint8, seed=1) + rgb_obs_space = Box(low=0, high=255, shape=(210, 160, 3), dtype=np.uint8, seed=1) + + def lives(self) -> int: + """Returns the number of lives in the atari game.""" + return 1 + + def getScreenGrayscale(self, buffer: np.ndarray): + """Updates the buffer with a random grayscale observation.""" + buffer[...] = self.grayscale_obs_space.sample() + + def getScreenRGB(self, buffer: np.ndarray): + """Updates the buffer with a random rgb observation.""" + buffer[...] = self.rgb_obs_space.sample() -def test_atari_preprocessing_grayscale(env_fn): - import cv2 +class AtariTestingEnv(GenericTestEnv): + """A testing environment to replicate the atari (ale-py) environments.""" - env1 = env_fn() - env2 = AtariPreprocessing( - env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0 - ) - env3 = AtariPreprocessing( - env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0 - ) - env4 = AtariPreprocessing( - env_fn(), - screen_size=84, - grayscale_obs=True, - frame_skip=1, - noop_max=0, - grayscale_newaxis=True, - ) - obs1 = env1.reset(seed=0) - obs2 = env2.reset(seed=0) - obs3 = env3.reset(seed=0) - obs4 = env4.reset(seed=0) - assert env1.observation_space.shape == (210, 160, 3) - assert env2.observation_space.shape == (84, 84) - assert env3.observation_space.shape == (84, 84, 3) - assert env4.observation_space.shape == (84, 84, 1) - assert obs1.shape == (210, 160, 3) - assert obs2.shape == (84, 84) - assert obs3.shape == (84, 84, 3) - assert obs4.shape == (84, 84, 1) - assert np.allclose(obs3, cv2.resize(obs1, (84, 84), interpolation=cv2.INTER_AREA)) - obs3_gray = cv2.cvtColor(obs3, cv2.COLOR_RGB2GRAY) - # the edges of the numbers do not render quite the same in the grayscale, so we ignore them - assert np.allclose(obs2[10:38], obs3_gray[10:38]) - # the paddle also do not render quite the same - assert np.allclose(obs2[44:], obs3_gray[44:]) - # now add a channel axis and re-test - obs3_gray = obs3_gray.reshape(84, 84, 1) - assert np.allclose(obs4[10:38], obs3_gray[10:38]) - assert np.allclose(obs4[44:], obs3_gray[44:]) + def __init__(self): + super().__init__( + observation_space=Box( + low=0, high=255, shape=(210, 160, 3), dtype=np.uint8, seed=1 + ), + action_space=Discrete(3, seed=1), + step_fn=old_step_fn, + ) + self.ale = AleTesting() - env1.close() - env2.close() - env3.close() - env4.close() + def get_action_meanings(self): + """Returns the meanings of each of the actions available to the agent. First index must be 'NOOP'.""" + return ["NOOP", "UP", "DOWN"] -def test_atari_preprocessing_scale(env_fn): - # arbitrarily chosen number for stepping into env. and ensuring all observations are in the required range - max_test_steps = 10 - - for grayscale in [True, False]: - for scaled in [True, False]: - env = AtariPreprocessing( - env_fn(), +@pytest.mark.parametrize( + "env, obs_shape", + [ + (AtariTestingEnv(), (210, 160, 3)), + ( + AtariPreprocessing( + AtariTestingEnv(), screen_size=84, - grayscale_obs=grayscale, - scale_obs=scaled, + grayscale_obs=True, frame_skip=1, noop_max=0, - ) - obs = env.reset().flatten() - done, step_i = False, 0 - max_obs = 1 if scaled else 255 - assert (0 <= obs).all() and ( - obs <= max_obs - ).all(), f"Obs. must be in range [0,{max_obs}]" - while not done or step_i <= max_test_steps: - obs, _, done, _ = env.step(env.action_space.sample()) - obs = obs.flatten() - assert (0 <= obs).all() and ( - obs <= max_obs - ).all(), f"Obs. must be in range [0,{max_obs}]" - step_i += 1 + ), + (84, 84), + ), + ( + AtariPreprocessing( + AtariTestingEnv(), + screen_size=84, + grayscale_obs=False, + frame_skip=1, + noop_max=0, + ), + (84, 84, 3), + ), + ( + AtariPreprocessing( + AtariTestingEnv(), + screen_size=84, + grayscale_obs=True, + frame_skip=1, + noop_max=0, + grayscale_newaxis=True, + ), + (84, 84, 1), + ), + ], +) +def test_atari_preprocessing_grayscale(env, obs_shape): + assert env.observation_space.shape == obs_shape - env.close() + # It is not possible to test the outputs as we are not using actual observations. + # todo: update when ale-py is compatible with the ci + + obs = env.reset(seed=0) + assert obs in env.observation_space + obs, _ = env.reset(seed=0, return_info=True) + assert obs in env.observation_space + + obs, _, _, _ = env.step(env.action_space.sample()) + assert obs in env.observation_space + + env.close() + + +@pytest.mark.parametrize("grayscale", [True, False]) +@pytest.mark.parametrize("scaled", [True, False]) +def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10): + # arbitrarily chosen number for stepping into env. and ensuring all observations are in the required range + env = AtariPreprocessing( + AtariTestingEnv(), + screen_size=84, + grayscale_obs=grayscale, + scale_obs=scaled, + frame_skip=1, + noop_max=0, + ) + + obs = env.reset() + + max_obs = 1 if scaled else 255 + assert np.all(0 <= obs) and np.all(obs <= max_obs) + + done, step_i = False, 0 + while not done and step_i <= max_test_steps: + obs, _, done, _ = env.step(env.action_space.sample()) + assert np.all(0 <= obs) and np.all(obs <= max_obs) + + step_i += 1 + env.close() diff --git a/tests/wrappers/test_frame_stack.py b/tests/wrappers/test_frame_stack.py index 826790f86..8c4ed0664 100644 --- a/tests/wrappers/test_frame_stack.py +++ b/tests/wrappers/test_frame_stack.py @@ -10,10 +10,7 @@ except ImportError: lz4 = None -pytest.importorskip("gym.envs.atari") - - -@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1", "Pong-v0"]) +@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1", "CarRacing-v2"]) @pytest.mark.parametrize("num_stack", [2, 3, 4]) @pytest.mark.parametrize( "lz4_compress", @@ -42,8 +39,13 @@ def test_frame_stack(env_id, num_stack, lz4_compress): for _ in range(num_stack**2): action = env.action_space.sample() - dup_obs, _, _, _ = dup.step(action) - obs, _, _, _ = env.step(action) + dup_obs, _, dup_done, _ = dup.step(action) + obs, _, done, _ = env.step(action) + + assert dup_done == done assert np.allclose(obs[-1], dup_obs) + if done: + break + assert len(obs) == num_stack diff --git a/tests/wrappers/test_gray_scale_observation.py b/tests/wrappers/test_gray_scale_observation.py index 58f878952..be63fc0d2 100644 --- a/tests/wrappers/test_gray_scale_observation.py +++ b/tests/wrappers/test_gray_scale_observation.py @@ -1,44 +1,26 @@ -import numpy as np import pytest import gym from gym import spaces -from gym.wrappers import AtariPreprocessing, GrayScaleObservation - -pytest.importorskip("gym.envs.atari") -pytest.importorskip("cv2") +from gym.wrappers import GrayScaleObservation -@pytest.mark.parametrize( - "env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"] -) +@pytest.mark.parametrize("env_id", ["CarRacing-v2"]) @pytest.mark.parametrize("keep_dim", [True, False]) def test_gray_scale_observation(env_id, keep_dim): - gray_env = AtariPreprocessing( - gym.make(env_id, disable_env_checker=True), screen_size=84, grayscale_obs=True - ) - rgb_env = AtariPreprocessing( - gym.make(env_id, disable_env_checker=True), screen_size=84, grayscale_obs=False - ) - wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim) + rgb_env = gym.make(env_id, disable_env_checker=True) assert isinstance(rgb_env.observation_space, spaces.Box) + assert len(rgb_env.observation_space.shape) == 3 assert rgb_env.observation_space.shape[-1] == 3 + wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim) assert isinstance(wrapped_env.observation_space, spaces.Box) - - seed = 0 - - gray_obs = gray_env.reset(seed=seed) - wrapped_obs = wrapped_env.reset(seed=seed) - if keep_dim: + assert len(wrapped_env.observation_space.shape) == 3 assert wrapped_env.observation_space.shape[-1] == 1 - assert len(wrapped_obs.shape) == 3 - wrapped_obs = wrapped_obs.squeeze(-1) else: assert len(wrapped_env.observation_space.shape) == 2 - assert len(wrapped_obs.shape) == 2 - # ALE gray scale is slightly different, but no more than by one shade - assert np.allclose(gray_obs.astype("int32"), wrapped_obs.astype("int32"), atol=1) + wrapped_obs = wrapped_env.reset() + assert wrapped_obs in wrapped_env.observation_space diff --git a/tests/wrappers/test_resize_observation.py b/tests/wrappers/test_resize_observation.py index 62b3a3d73..b0553df60 100644 --- a/tests/wrappers/test_resize_observation.py +++ b/tests/wrappers/test_resize_observation.py @@ -4,12 +4,8 @@ import gym from gym import spaces from gym.wrappers import ResizeObservation -pytest.importorskip("gym.envs.atari") - -@pytest.mark.parametrize( - "env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"] -) +@pytest.mark.parametrize("env_id", ["CarRacing-v2"]) @pytest.mark.parametrize("shape", [16, 32, (8, 5), [10, 7]]) def test_resize_observation(env_id, shape): env = gym.make(env_id, disable_env_checker=True)