Remove pytest.importorskip (#2976)

This commit is contained in:
Mark Towers
2022-07-23 15:38:52 +01:00
committed by GitHub
parent 8461425286
commit 3c60ae97d1
5 changed files with 134 additions and 114 deletions

View File

@@ -24,12 +24,18 @@ def basic_reset_fn(
return self.observation_space.sample() return self.observation_space.sample()
def basic_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: def new_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
"""A basic step function that will pass the environment check using random actions from the observation space.""" """A step function that follows the new step api that will pass the environment check using random actions from the observation space."""
return self.observation_space.sample(), 0, False, False, {} return self.observation_space.sample(), 0, False, False, {}
def old_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
"""A step function that follows the old step api that will pass the environment check using random actions from the observation space."""
return self.observation_space.sample(), 0, False, {}
def basic_render_fn(self): def basic_render_fn(self):
"""Basic render fn that does nothing."""
pass pass
@@ -42,7 +48,7 @@ class GenericTestEnv(gym.Env):
action_space: spaces.Space = spaces.Box(0, 1, (1,)), action_space: spaces.Space = spaces.Box(0, 1, (1,)),
observation_space: spaces.Space = spaces.Box(0, 1, (1,)), observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
reset_fn: callable = basic_reset_fn, reset_fn: callable = basic_reset_fn,
step_fn: callable = basic_step_fn, step_fn: callable = new_step_fn,
render_fn: callable = basic_render_fn, render_fn: callable = basic_render_fn,
render_modes: Optional[List[str]] = None, render_modes: Optional[List[str]] = None,
render_fps: Optional[int] = None, render_fps: Optional[int] = None,

View File

@@ -1,90 +1,124 @@
import numpy as np import numpy as np
import pytest import pytest
import gym from gym.spaces import Box, Discrete
from gym.wrappers import AtariPreprocessing from gym.wrappers import AtariPreprocessing
from tests.testing_env import GenericTestEnv, old_step_fn
pytest.importorskip("gym.envs.atari")
@pytest.fixture(scope="module") class AleTesting:
def env_fn(): """A testing implementation for the ALE object in atari games."""
return lambda: gym.make("PongNoFrameskip-v4", disable_env_checker=True)
grayscale_obs_space = Box(low=0, high=255, shape=(210, 160), dtype=np.uint8, seed=1)
rgb_obs_space = Box(low=0, high=255, shape=(210, 160, 3), dtype=np.uint8, seed=1)
def lives(self) -> int:
"""Returns the number of lives in the atari game."""
return 1
def getScreenGrayscale(self, buffer: np.ndarray):
"""Updates the buffer with a random grayscale observation."""
buffer[...] = self.grayscale_obs_space.sample()
def getScreenRGB(self, buffer: np.ndarray):
"""Updates the buffer with a random rgb observation."""
buffer[...] = self.rgb_obs_space.sample()
def test_atari_preprocessing_grayscale(env_fn): class AtariTestingEnv(GenericTestEnv):
import cv2 """A testing environment to replicate the atari (ale-py) environments."""
env1 = env_fn() def __init__(self):
env2 = AtariPreprocessing( super().__init__(
env_fn(), screen_size=84, grayscale_obs=True, frame_skip=1, noop_max=0 observation_space=Box(
low=0, high=255, shape=(210, 160, 3), dtype=np.uint8, seed=1
),
action_space=Discrete(3, seed=1),
step_fn=old_step_fn,
) )
env3 = AtariPreprocessing( self.ale = AleTesting()
env_fn(), screen_size=84, grayscale_obs=False, frame_skip=1, noop_max=0
) def get_action_meanings(self):
env4 = AtariPreprocessing( """Returns the meanings of each of the actions available to the agent. First index must be 'NOOP'."""
env_fn(), return ["NOOP", "UP", "DOWN"]
@pytest.mark.parametrize(
"env, obs_shape",
[
(AtariTestingEnv(), (210, 160, 3)),
(
AtariPreprocessing(
AtariTestingEnv(),
screen_size=84,
grayscale_obs=True,
frame_skip=1,
noop_max=0,
),
(84, 84),
),
(
AtariPreprocessing(
AtariTestingEnv(),
screen_size=84,
grayscale_obs=False,
frame_skip=1,
noop_max=0,
),
(84, 84, 3),
),
(
AtariPreprocessing(
AtariTestingEnv(),
screen_size=84, screen_size=84,
grayscale_obs=True, grayscale_obs=True,
frame_skip=1, frame_skip=1,
noop_max=0, noop_max=0,
grayscale_newaxis=True, grayscale_newaxis=True,
),
(84, 84, 1),
),
],
) )
obs1 = env1.reset(seed=0) def test_atari_preprocessing_grayscale(env, obs_shape):
obs2 = env2.reset(seed=0) assert env.observation_space.shape == obs_shape
obs3 = env3.reset(seed=0)
obs4 = env4.reset(seed=0)
assert env1.observation_space.shape == (210, 160, 3)
assert env2.observation_space.shape == (84, 84)
assert env3.observation_space.shape == (84, 84, 3)
assert env4.observation_space.shape == (84, 84, 1)
assert obs1.shape == (210, 160, 3)
assert obs2.shape == (84, 84)
assert obs3.shape == (84, 84, 3)
assert obs4.shape == (84, 84, 1)
assert np.allclose(obs3, cv2.resize(obs1, (84, 84), interpolation=cv2.INTER_AREA))
obs3_gray = cv2.cvtColor(obs3, cv2.COLOR_RGB2GRAY)
# the edges of the numbers do not render quite the same in the grayscale, so we ignore them
assert np.allclose(obs2[10:38], obs3_gray[10:38])
# the paddle also do not render quite the same
assert np.allclose(obs2[44:], obs3_gray[44:])
# now add a channel axis and re-test
obs3_gray = obs3_gray.reshape(84, 84, 1)
assert np.allclose(obs4[10:38], obs3_gray[10:38])
assert np.allclose(obs4[44:], obs3_gray[44:])
env1.close() # It is not possible to test the outputs as we are not using actual observations.
env2.close() # todo: update when ale-py is compatible with the ci
env3.close()
env4.close() obs = env.reset(seed=0)
assert obs in env.observation_space
obs, _ = env.reset(seed=0, return_info=True)
assert obs in env.observation_space
obs, _, _, _ = env.step(env.action_space.sample())
assert obs in env.observation_space
env.close()
def test_atari_preprocessing_scale(env_fn): @pytest.mark.parametrize("grayscale", [True, False])
@pytest.mark.parametrize("scaled", [True, False])
def test_atari_preprocessing_scale(grayscale, scaled, max_test_steps=10):
# arbitrarily chosen number for stepping into env. and ensuring all observations are in the required range # arbitrarily chosen number for stepping into env. and ensuring all observations are in the required range
max_test_steps = 10
for grayscale in [True, False]:
for scaled in [True, False]:
env = AtariPreprocessing( env = AtariPreprocessing(
env_fn(), AtariTestingEnv(),
screen_size=84, screen_size=84,
grayscale_obs=grayscale, grayscale_obs=grayscale,
scale_obs=scaled, scale_obs=scaled,
frame_skip=1, frame_skip=1,
noop_max=0, noop_max=0,
) )
obs = env.reset().flatten()
done, step_i = False, 0
max_obs = 1 if scaled else 255
assert (0 <= obs).all() and (
obs <= max_obs
).all(), f"Obs. must be in range [0,{max_obs}]"
while not done or step_i <= max_test_steps:
obs, _, done, _ = env.step(env.action_space.sample())
obs = obs.flatten()
assert (0 <= obs).all() and (
obs <= max_obs
).all(), f"Obs. must be in range [0,{max_obs}]"
step_i += 1
obs = env.reset()
max_obs = 1 if scaled else 255
assert np.all(0 <= obs) and np.all(obs <= max_obs)
done, step_i = False, 0
while not done and step_i <= max_test_steps:
obs, _, done, _ = env.step(env.action_space.sample())
assert np.all(0 <= obs) and np.all(obs <= max_obs)
step_i += 1
env.close() env.close()

View File

@@ -10,10 +10,7 @@ except ImportError:
lz4 = None lz4 = None
pytest.importorskip("gym.envs.atari") @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1", "CarRacing-v2"])
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1", "Pong-v0"])
@pytest.mark.parametrize("num_stack", [2, 3, 4]) @pytest.mark.parametrize("num_stack", [2, 3, 4])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"lz4_compress", "lz4_compress",
@@ -42,8 +39,13 @@ def test_frame_stack(env_id, num_stack, lz4_compress):
for _ in range(num_stack**2): for _ in range(num_stack**2):
action = env.action_space.sample() action = env.action_space.sample()
dup_obs, _, _, _ = dup.step(action) dup_obs, _, dup_done, _ = dup.step(action)
obs, _, _, _ = env.step(action) obs, _, done, _ = env.step(action)
assert dup_done == done
assert np.allclose(obs[-1], dup_obs) assert np.allclose(obs[-1], dup_obs)
if done:
break
assert len(obs) == num_stack assert len(obs) == num_stack

View File

@@ -1,44 +1,26 @@
import numpy as np
import pytest import pytest
import gym import gym
from gym import spaces from gym import spaces
from gym.wrappers import AtariPreprocessing, GrayScaleObservation from gym.wrappers import GrayScaleObservation
pytest.importorskip("gym.envs.atari")
pytest.importorskip("cv2")
@pytest.mark.parametrize( @pytest.mark.parametrize("env_id", ["CarRacing-v2"])
"env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]
)
@pytest.mark.parametrize("keep_dim", [True, False]) @pytest.mark.parametrize("keep_dim", [True, False])
def test_gray_scale_observation(env_id, keep_dim): def test_gray_scale_observation(env_id, keep_dim):
gray_env = AtariPreprocessing( rgb_env = gym.make(env_id, disable_env_checker=True)
gym.make(env_id, disable_env_checker=True), screen_size=84, grayscale_obs=True
)
rgb_env = AtariPreprocessing(
gym.make(env_id, disable_env_checker=True), screen_size=84, grayscale_obs=False
)
wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim)
assert isinstance(rgb_env.observation_space, spaces.Box) assert isinstance(rgb_env.observation_space, spaces.Box)
assert len(rgb_env.observation_space.shape) == 3
assert rgb_env.observation_space.shape[-1] == 3 assert rgb_env.observation_space.shape[-1] == 3
wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim)
assert isinstance(wrapped_env.observation_space, spaces.Box) assert isinstance(wrapped_env.observation_space, spaces.Box)
seed = 0
gray_obs = gray_env.reset(seed=seed)
wrapped_obs = wrapped_env.reset(seed=seed)
if keep_dim: if keep_dim:
assert len(wrapped_env.observation_space.shape) == 3
assert wrapped_env.observation_space.shape[-1] == 1 assert wrapped_env.observation_space.shape[-1] == 1
assert len(wrapped_obs.shape) == 3
wrapped_obs = wrapped_obs.squeeze(-1)
else: else:
assert len(wrapped_env.observation_space.shape) == 2 assert len(wrapped_env.observation_space.shape) == 2
assert len(wrapped_obs.shape) == 2
# ALE gray scale is slightly different, but no more than by one shade wrapped_obs = wrapped_env.reset()
assert np.allclose(gray_obs.astype("int32"), wrapped_obs.astype("int32"), atol=1) assert wrapped_obs in wrapped_env.observation_space

View File

@@ -4,12 +4,8 @@ import gym
from gym import spaces from gym import spaces
from gym.wrappers import ResizeObservation from gym.wrappers import ResizeObservation
pytest.importorskip("gym.envs.atari")
@pytest.mark.parametrize("env_id", ["CarRacing-v2"])
@pytest.mark.parametrize(
"env_id", ["PongNoFrameskip-v0", "SpaceInvadersNoFrameskip-v0"]
)
@pytest.mark.parametrize("shape", [16, 32, (8, 5), [10, 7]]) @pytest.mark.parametrize("shape", [16, 32, (8, 5), [10, 7]])
def test_resize_observation(env_id, shape): def test_resize_observation(env_id, shape):
env = gym.make(env_id, disable_env_checker=True) env = gym.make(env_id, disable_env_checker=True)