Full type hinting (#2942)

* Allows a new RNG to be generated with seed=-1 and updated env_checker to fix bug if environment doesn't use np_random in reset

* Revert "fixed `gym.vector.make` where the checker was being applied in the opposite case than was intended to (#2871)"

This reverts commit 519dfd9117.

* Remove bad pushed commits

* Fixed spelling in core.py

* Pins pytest to the last py 3.6 version

* Allow Box automatic scalar shape

* Add test box and change default from () to (1,)

* update Box shape inference with more strict checking

* Update the box shape and add check on the custom Box shape

* Removed incorrect shape type and assert shape code

* Update the Box and associated tests

* Remove all folders and files from pyright exclude

* Revert issues

* Push RedTachyon code review

* Add Python Platform

* Remove play from pyright check

* Fixed CI issues

* remove mujoco env type hinting

* Fixed pixel observation test

* Added some new type hints

* Fixed CI errors

* Fixed CI errors

* Remove play.py from pyright exclude

* Fixed pyright issues
This commit is contained in:
Mark Towers
2022-07-04 18:19:25 +01:00
committed by GitHub
parent 9e66399b4e
commit 2ede09074f
61 changed files with 352 additions and 286 deletions

View File

@@ -1,6 +1,7 @@
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
import sys
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
@@ -17,6 +18,9 @@ from gym.logger import deprecation, warn
from gym.utils import seeding
from gym.utils.seeding import RandomNumberGenerator
if TYPE_CHECKING:
from gym.envs.registration import EnvSpec
if sys.version_info[0:2] == (3, 6):
warn(
"Gym minimally supports python 3.6 as the python foundation not longer supports the version, please update your version to 3.7+"
@@ -106,7 +110,7 @@ class Env(Generic[ObsType, ActType], metaclass=decorator):
metadata = {"render_modes": []}
render_mode = None # define render_mode if your environment supports rendering
reward_range = (-float("inf"), float("inf"))
spec = None
spec: "EnvSpec" = None
# Set these in ALL subclasses
action_space: spaces.Space[ActType]

View File

@@ -1,7 +1,7 @@
__credits__ = ["Andrea PIERRÉ"]
import math
from typing import Optional
from typing import TYPE_CHECKING, List, Optional
import numpy as np
@@ -25,6 +25,9 @@ except ImportError:
raise DependencyNotInstalled("box2D is not installed, run `pip install gym[box2d]`")
if TYPE_CHECKING:
import pygame
FPS = 50
SCALE = 30.0 # affects how fast-paced the game is, forces should be adjusted as well
@@ -170,8 +173,8 @@ class BipedalWalker(gym.Env, EzPickle):
self.isopen = True
self.world = Box2D.b2World()
self.terrain = None
self.hull = None
self.terrain: List[Box2D.b2Body] = []
self.hull: Optional[Box2D.b2Body] = None
self.prev_shaping = None
@@ -256,7 +259,7 @@ class BipedalWalker(gym.Env, EzPickle):
self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
self.screen = None
self.screen: Optional[pygame.Surface] = None
self.clock = None
def _destroy(self):
@@ -283,6 +286,9 @@ class BipedalWalker(gym.Env, EzPickle):
self.terrain = []
self.terrain_x = []
self.terrain_y = []
stair_steps, stair_width, stair_height = 0, 0, 0
original_y = 0
for i in range(TERRAIN_LENGTH):
x = i * TERRAIN_STEP
self.terrain_x.append(x)
@@ -448,8 +454,8 @@ class BipedalWalker(gym.Env, EzPickle):
(self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True
)
self.legs = []
self.joints = []
self.legs: List[Box2D.b2Body] = []
self.joints: List[Box2D.b2RevoluteJoint] = []
for i in [-1, +1]:
leg = self.world.CreateDynamicBody(
position=(init_x, init_y - LEG_H / 2 - LEG_DOWN),
@@ -514,6 +520,8 @@ class BipedalWalker(gym.Env, EzPickle):
return self.step(np.array([0, 0, 0, 0]))[0], {}
def step(self, action: np.ndarray):
assert self.hull is not None
# self.hull.ApplyForceToCenter((0, 20), True) -- Uncomment this to receive a bit of stability help
control_speed = False # Should be easier as well
if control_speed:
@@ -737,6 +745,7 @@ class BipedalWalker(gym.Env, EzPickle):
self.surf = pygame.transform.flip(self.surf, False, True)
if mode == "human":
assert self.screen is not None
self.screen.blit(self.surf, (-self.scroll * SCALE, 0))
pygame.event.pump()
self.clock.tick(self.metadata["render_fps"])

View File

@@ -9,6 +9,7 @@ Created by Oleg Klimov
import math
import Box2D
import numpy as np
from gym.error import DependencyNotInstalled
@@ -48,8 +49,8 @@ MUD_COLOR = (102, 102, 0)
class Car:
def __init__(self, world, init_angle, init_x, init_y):
self.world = world
self.hull = self.world.CreateDynamicBody(
self.world: Box2D.b2World = world
self.hull: Box2D.b2Body = self.world.CreateDynamicBody(
position=(init_x, init_y),
angle=init_angle,
fixtures=[

View File

@@ -197,14 +197,14 @@ class CarRacing(gym.Env, EzPickle):
self.contactListener_keepref = FrictionDetector(self, lap_complete_percent)
self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref)
self.screen = None
self.screen: Optional[pygame.Surface] = None
self.surf = None
self.clock = None
self.isopen = True
self.invisible_state_window = None
self.invisible_video_window = None
self.road = None
self.car = None
self.car: Optional[Car] = None
self.reward = 0.0
self.prev_reward = 0.0
self.verbose = verbose
@@ -237,6 +237,7 @@ class CarRacing(gym.Env, EzPickle):
for t in self.road:
self.world.DestroyBody(t)
self.road = []
assert self.car is not None
self.car.destroy()
def _init_colors(self):
@@ -502,6 +503,7 @@ class CarRacing(gym.Env, EzPickle):
return self.step(None)[0], {}
def step(self, action: Union[np.ndarray, int]):
assert self.car is not None
if action is not None:
if self.continuous:
self.car.steer(-action[0])
@@ -576,6 +578,7 @@ class CarRacing(gym.Env, EzPickle):
self.surf = pygame.Surface((WINDOW_W, WINDOW_H))
assert self.car is not None
# computing transformations
angle = -self.car.hull.angle
# Animating first second zoom.
@@ -608,6 +611,7 @@ class CarRacing(gym.Env, EzPickle):
if mode == "human":
pygame.event.pump()
self.clock.tick(self.metadata["render_fps"])
assert self.screen is not None
self.screen.fill(0)
self.screen.blit(self.surf, (0, 0))
pygame.display.flip()
@@ -682,6 +686,7 @@ class CarRacing(gym.Env, EzPickle):
((place + 0) * s, H - 2 * h),
]
assert self.car is not None
true_speed = np.sqrt(
np.square(self.car.hull.linearVelocity[0])
+ np.square(self.car.hull.linearVelocity[1])

View File

@@ -2,7 +2,7 @@ __credits__ = ["Andrea PIERRÉ"]
import math
import warnings
from typing import Optional
from typing import TYPE_CHECKING, Optional
import numpy as np
@@ -25,6 +25,11 @@ try:
except ImportError:
raise DependencyNotInstalled("box2d is not installed, run `pip install gym[box2d]`")
if TYPE_CHECKING:
import pygame
FPS = 50
SCALE = 30.0 # affects how fast-paced the game is, forces should be adjusted as well
@@ -215,7 +220,7 @@ class LunarLander(gym.Env, EzPickle):
self.wind_idx = np.random.randint(-9999, 9999)
self.torque_idx = np.random.randint(-9999, 9999)
self.screen = None
self.screen: pygame.Surface = None
self.clock = None
self.isopen = True
self.world = Box2D.b2World(gravity=(0, gravity))
@@ -427,6 +432,8 @@ class LunarLander(gym.Env, EzPickle):
self.world.DestroyBody(self.particles.pop(0))
def step(self, action):
assert self.lander is not None
# Update wind
assert self.lander is not None, "You forgot to call reset()"
if self.enable_wind and not (
@@ -604,10 +611,6 @@ class LunarLander(gym.Env, EzPickle):
if self.clock is None:
self.clock = pygame.time.Clock()
assert (
self.screen is not None
), "Something went wrong with pygame, there is no screen to render"
self.surf = pygame.Surface((VIEWPORT_W, VIEWPORT_H))
pygame.transform.scale(self.surf, (SCALE, SCALE))

View File

@@ -79,4 +79,5 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs()
def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.distance = self.model.stat.extent * 0.5

View File

@@ -173,6 +173,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation
def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value

View File

@@ -336,6 +336,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation
def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value

View File

@@ -60,4 +60,5 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs()
def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.distance = self.model.stat.extent * 0.5

View File

@@ -118,6 +118,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation
def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value

View File

@@ -231,6 +231,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation
def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value

View File

@@ -66,6 +66,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs()
def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 2
self.viewer.cam.distance = self.model.stat.extent * 0.75
self.viewer.cam.lookat[2] = 1.15

View File

@@ -163,6 +163,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation
def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value

View File

@@ -279,6 +279,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation
def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value

View File

@@ -92,6 +92,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs()
def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 1
self.viewer.cam.distance = self.model.stat.extent * 1.0
self.viewer.cam.lookat[2] = 2.0

View File

@@ -185,6 +185,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation
def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value

View File

@@ -358,6 +358,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation
def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value

View File

@@ -83,6 +83,7 @@ class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs()
def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 1
self.viewer.cam.distance = self.model.stat.extent * 1.0
self.viewer.cam.lookat[2] = 0.8925

View File

@@ -254,6 +254,7 @@ class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs()
def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 1
self.viewer.cam.distance = self.model.stat.extent * 1.0
self.viewer.cam.lookat[2] = 0.8925

View File

@@ -64,6 +64,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs()
def viewer_setup(self):
assert self.viewer is not None
v = self.viewer
v.cam.trackbodyid = 0
v.cam.distance = self.model.stat.extent * 0.5

View File

@@ -169,6 +169,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs()
def viewer_setup(self):
assert self.viewer is not None
v = self.viewer
v.cam.trackbodyid = 0
v.cam.distance = self.model.stat.extent * 0.5

View File

@@ -54,6 +54,6 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return np.concatenate([self.sim.data.qpos, self.sim.data.qvel]).ravel()
def viewer_setup(self):
v = self.viewer
v.cam.trackbodyid = 0
v.cam.distance = self.model.stat.extent
assert self.viewer is not None
self.viewer.cam.trackbodyid = 0
self.viewer.cam.distance = self.model.stat.extent

View File

@@ -130,6 +130,7 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return np.concatenate([self.data.qpos, self.data.qvel]).ravel()
def viewer_setup(self):
assert self.viewer is not None
v = self.viewer
v.cam.trackbodyid = 0
v.cam.distance = self.model.stat.extent

View File

@@ -47,6 +47,7 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = -1
self.viewer.cam.distance = 4.0

View File

@@ -164,6 +164,7 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = -1
self.viewer.cam.distance = 4.0

View File

@@ -43,6 +43,7 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 0
def reset_model(self):

View File

@@ -150,6 +150,7 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 0
def reset_model(self):

View File

@@ -119,6 +119,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation
def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value

View File

@@ -225,6 +225,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation
def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value

View File

@@ -60,6 +60,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs()
def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 2
self.viewer.cam.distance = self.model.stat.extent * 0.5
self.viewer.cam.lookat[2] = 1.15

View File

@@ -153,6 +153,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation
def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value

View File

@@ -277,6 +277,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation
def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value

View File

@@ -134,10 +134,12 @@ class Graph(Space):
assert (
num_edges >= 0
), f"The number of edges is expected to be greater than 0, actual mask: {num_edges}"
assert num_edges is not None
sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges)
assert sampled_node_space is not None
sampled_nodes = sampled_node_space.sample(node_space_mask)
sampled_edges = (
sampled_edge_space.sample(edge_space_mask)

View File

@@ -3,16 +3,23 @@ from collections import deque
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import pygame
from pygame import Surface
from pygame.event import Event
from pygame.locals import VIDEORESIZE
import gym.error
from gym import Env, logger
from gym.core import ActType, ObsType
from gym.error import DependencyNotInstalled
from gym.logger import deprecation
try:
import pygame
from pygame import Surface
from pygame.event import Event
from pygame.locals import VIDEORESIZE
except ImportError:
raise gym.error.DependencyNotInstalled(
"Pygame is not installed, run `pip install gym[classic_control]`"
)
try:
import matplotlib
@@ -20,7 +27,7 @@ try:
import matplotlib.pyplot as plt
except ImportError:
logger.warn("Matplotlib is not installed, run `pip install gym[other]`")
matplotlib, plt = None, None
plt = None
class MissingKeysToAction(Exception):
@@ -33,7 +40,7 @@ class PlayableGame:
def __init__(
self,
env: Env,
keys_to_action: Optional[Dict[Tuple[int], int]] = None,
keys_to_action: Optional[Dict[Tuple[int, ...], int]] = None,
zoom: Optional[float] = None,
):
"""Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.
@@ -63,12 +70,14 @@ class PlayableGame:
f"{self.env.spec.id} does not have explicit key to action mapping, "
"please specify one manually"
)
assert isinstance(keys_to_action, dict)
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
return relevant_keys
def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]:
# TODO: this needs to be updated when the render API change goes through
rendered = self.env.render(mode="rgb_array")
assert rendered is not None and isinstance(rendered, np.ndarray)
video_size = [rendered.shape[1], rendered.shape[0]]
if zoom is not None:
@@ -211,9 +220,9 @@ def play(
f"{env.spec.id} does not have explicit key to action mapping, "
"please specify one manually"
)
assert keys_to_action is not None
key_code_to_action = {}
for key_combination, action in keys_to_action.items():
key_code = tuple(
sorted(ord(key) if isinstance(key, str) else key for key in key_combination)
@@ -225,7 +234,7 @@ def play(
if fps is None:
fps = env.metadata.get("render_fps", 30)
done = True
done, obs = True, None
clock = pygame.time.Clock()
while game.running:
@@ -316,7 +325,7 @@ class PlayPlot:
for axis, name in zip(self.ax, plot_names):
axis.set_title(name)
self.t = 0
self.cur_plot = [None for _ in range(num_plots)]
self.cur_plot: List[Optional[plt.Axes]] = [None for _ in range(num_plots)]
self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]
def callback(
@@ -352,4 +361,9 @@ class PlayPlot:
range(xmin, xmax), list(self.data[i]), c="blue"
)
self.ax[i].set_xlim(xmin, xmax)
if plt is None:
raise DependencyNotInstalled(
"matplotlib is not installed, run `pip install gym[other]`"
)
plt.pause(0.000001)

View File

@@ -569,7 +569,7 @@ class AsyncVectorEnv(VectorEnv):
num_errors = self.num_envs - sum(successes)
assert num_errors > 0
for _ in range(num_errors):
for i in range(num_errors):
index, exctype, value = self.error_queue.get()
logger.error(
f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
@@ -578,6 +578,7 @@ class AsyncVectorEnv(VectorEnv):
self.parent_pipes[index].close()
self.parent_pipes[index] = None
if i == num_errors - 1:
logger.error("Raising the last exception back to the main process.")
raise exctype(value)

View File

@@ -1,9 +1,10 @@
"""A synchronous vector environment."""
from copy import deepcopy
from typing import Any, Iterator, List, Optional, Sequence, Union
from typing import Any, Callable, Iterator, List, Optional, Sequence, Union
import numpy as np
from gym import Env
from gym.spaces import Space
from gym.vector.utils import concatenate, create_empty_array, iterate
from gym.vector.vector_env import VectorEnv
@@ -28,7 +29,7 @@ class SyncVectorEnv(VectorEnv):
def __init__(
self,
env_fns: Iterator[callable],
env_fns: Iterator[Callable[[], Env]],
observation_space: Space = None,
action_space: Space = None,
copy: bool = True,

View File

@@ -6,7 +6,8 @@ from gym.vector.utils.shared_memory import (
read_from_shared_memory,
write_to_shared_memory,
)
from gym.vector.utils.spaces import BaseGymSpaces, _BaseGymSpaces, batch_space, iterate
from gym.vector.utils.spaces import _BaseGymSpaces # pyright: reportPrivateUsage=false
from gym.vector.utils.spaces import BaseGymSpaces, batch_space, iterate
__all__ = [
"CloudpickleWrapper",
@@ -17,7 +18,6 @@ __all__ = [
"read_from_shared_memory",
"write_to_shared_memory",
"BaseGymSpaces",
"_BaseGymSpaces",
"batch_space",
"iterate",
]

View File

@@ -85,6 +85,7 @@ class VectorEnv(gym.Env):
Raises:
NotImplementedError: VectorEnv does not implement function
"""
raise NotImplementedError("VectorEnv does not implement function")
def reset(
self,

View File

@@ -2,13 +2,14 @@
import numpy as np
import gym
from gym.error import DependencyNotInstalled
from gym.spaces import Box
try:
import cv2
except ImportError:
cv2 = None
raise gym.error.DependencyNotInstalled(
"opencv-python package not installed, run `pip install gym[other]` to get dependencies for atari"
)
class AtariPreprocessing(gym.Wrapper):
@@ -60,10 +61,6 @@ class AtariPreprocessing(gym.Wrapper):
ValueError: Disable frame-skipping in the original env
"""
super().__init__(env)
if cv2 is None:
raise DependencyNotInstalled(
"opencv-python package not installed, run `pip install gym[other]` to get dependencies for atari"
)
assert frame_skip > 0
assert screen_size > 0
assert noop_max >= 0
@@ -87,6 +84,7 @@ class AtariPreprocessing(gym.Wrapper):
self.scale_obs = scale_obs
# buffer of most recent two observations for max pooling
assert isinstance(env.observation_space, Box)
if grayscale_obs:
self.obs_buffer = [
np.empty(env.observation_space.shape[:2], dtype=np.uint8),
@@ -114,7 +112,7 @@ class AtariPreprocessing(gym.Wrapper):
def step(self, action):
"""Applies the preprocessing for an :meth:`env.step`."""
total_reward = 0.0
total_reward, done, info = 0.0, False, {}
for t in range(self.frame_skip):
_, reward, done, info = self.env.step(action)

View File

@@ -32,8 +32,9 @@ class GrayScaleObservation(gym.ObservationWrapper):
self.keep_dim = keep_dim
assert (
len(env.observation_space.shape) == 3
and env.observation_space.shape[-1] == 3
isinstance(self.observation_space, Box)
and len(self.observation_space.shape) == 3
and self.observation_space.shape[-1] == 3
)
obs_shape = self.observation_space.shape[:2]

View File

@@ -88,13 +88,16 @@ class HumanRendering(gym.Wrapper):
"pygame is not installed, run `pip install gym[box2d]`"
)
if self.env.render_mode == "rgb_array":
last_rgb_array = self.env.render(**kwargs)[-1]
last_rgb_array = self.env.render(**kwargs)
assert isinstance(last_rgb_array, list)
last_rgb_array = last_rgb_array[-1]
elif self.env.render_mode == "single_rgb_array":
last_rgb_array = self.env.render(**kwargs)
else:
raise Exception(
"Wrapped environment must have mode 'rgb_array' or 'single_rgb_array'"
f"Wrapped environment must have mode 'rgb_array' or 'single_rgb_array', actual render mode: {self.env.render_mode}"
)
assert isinstance(last_rgb_array, np.ndarray)
if mode == "human":
rgb_array = np.transpose(last_rgb_array, axes=(1, 0, 2))

View File

@@ -148,7 +148,9 @@ class VideoRecorder: # TODO: remove with gym 1.0
)
self.output_frames_per_sec = self.backward_compatible_output_frames_per_sec
self.encoder = None # lazily start the process
self.encoder: Optional[
Union[TextEncoder, ImageEncoder]
] = None # lazily start the process
self.broken = False
# Dump metadata
@@ -387,7 +389,7 @@ class ImageEncoder:
InvalidFrame: Expects frame to have shape (w,h,3) or (w,h,4)
DependencyNotInstalled: Found neither the ffmpeg nor avconv executables.
"""
self.proc = None
self.proc: Optional[subprocess.Popen] = None
self.output_path = output_path
# Frame shape should be lines-first, so w and h are swapped
h, w, pixfmt = frame_shape
@@ -488,6 +490,7 @@ class ImageEncoder:
f"Your frame has data type {frame.dtype}, but we require uint8 (i.e. RGB values from 0-255)."
)
assert self.proc is not None and self.proc.stdin is not None
try:
self.proc.stdin.write(frame.tobytes())
except Exception:
@@ -496,6 +499,7 @@ class ImageEncoder:
def close(self):
"""Closes the Image encoder."""
assert self.proc is not None and self.proc.stdin is not None
self.proc.stdin.close()
ret = self.proc.wait()
if ret != 0:

View File

@@ -81,19 +81,20 @@ class NormalizeObservation(gym.core.Wrapper):
def reset(self, **kwargs):
"""Resets the environment and normalizes the observation."""
return_info = kwargs.get("return_info", False)
if return_info:
if kwargs.get("return_info", False):
obs, info = self.env.reset(**kwargs)
if self.is_vector_env:
return self.normalize(obs), info
else:
return self.normalize(np.array([obs]))[0], info
else:
obs = self.env.reset(**kwargs)
if self.is_vector_env:
obs = self.normalize(obs)
return self.normalize(obs)
else:
obs = self.normalize(np.array([obs]))[0]
if not return_info:
return obs
else:
return obs, info
return self.normalize(np.array([obs]))[0]
def normalize(self, obs):
"""Normalises the observation using the running mean and variance of the observations."""

View File

@@ -49,3 +49,8 @@ class OrderEnforcing(gym.Wrapper):
"set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
)
return self.env.render(*args, **kwargs)
@property
def has_reset(self):
"""Returns if the environment has been reset before."""
return self._has_reset

View File

@@ -120,8 +120,7 @@ class PixelObservationWrapper(gym.ObservationWrapper):
elif self._observation_is_dict:
self.observation_space = copy.deepcopy(wrapped_observation_space)
else:
self.observation_space = spaces.Dict()
self.observation_space.spaces[STATE_KEY] = wrapped_observation_space
self.observation_space = spaces.Dict({STATE_KEY: wrapped_observation_space})
# Extend observation space with pixels.
@@ -129,7 +128,7 @@ class PixelObservationWrapper(gym.ObservationWrapper):
pixels_spaces = {}
for pixel_key in pixel_keys:
pixels = self.env.render(**render_kwargs[pixel_key])
pixels = pixels[-1] if isinstance(pixels, List) else pixels
pixels: np.ndarray = pixels[-1] if isinstance(pixels, List) else pixels
if np.issubdtype(pixels.dtype, np.integer):
low, high = (0, 255)

View File

@@ -1,6 +1,7 @@
"""Wrapper that tracks the cumulative rewards and episode lengths."""
import time
from collections import deque
from typing import Optional
import numpy as np
@@ -86,8 +87,8 @@ class RecordEpisodeStatistics(gym.Wrapper):
self.num_envs = getattr(env, "num_envs", 1)
self.t0 = time.perf_counter()
self.episode_count = 0
self.episode_returns = None
self.episode_lengths = None
self.episode_returns: Optional[np.ndarray] = None
self.episode_lengths: Optional[np.ndarray] = None
self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size)
self.is_vector_env = getattr(env, "is_vector_env", False)

View File

@@ -1,6 +1,6 @@
"""Wrapper for recording videos."""
import os
from typing import Callable
from typing import Callable, Optional
import gym
from gym import logger
@@ -77,7 +77,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
self.episode_trigger = episode_trigger
self.step_trigger = step_trigger
self.video_recorder = None
self.video_recorder: Optional[video_recorder.VideoRecorder] = None
self.video_folder = os.path.abspath(video_folder)
# Create output folder if needed
@@ -101,6 +101,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
"""Reset the environment using kwargs and then starts recording if video enabled."""
observations = super().reset(**kwargs)
if self.recording:
assert self.video_recorder is not None
self.video_recorder.capture_frame()
self.recorded_frames += 1
if self.video_length > 0:
@@ -148,6 +149,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
self.episode_id += 1
if self.recording:
assert self.video_recorder is not None
self.video_recorder.capture_frame()
self.recorded_frames += 1
if self.video_length > 0:
@@ -168,6 +170,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
def close_video_recorder(self):
"""Closes the video recorder if currently recording."""
if self.recording:
assert self.video_recorder is not None
self.video_recorder.close()
self.recording = False
self.recorded_frames = 1

View File

@@ -39,7 +39,10 @@ class ResizeObservation(gym.ObservationWrapper):
self.shape = tuple(shape)
obs_shape = self.shape + self.observation_space.shape[2:]
assert isinstance(
env.observation_space, Box
), f"Expected the observation space to be Box, actual type: {type(env.observation_space)}"
obs_shape = self.shape + env.observation_space.shape[2:]
self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
def observation(self, observation):

View File

@@ -9,40 +9,7 @@ exclude = [
"**/node_modules",
"**/__pycache__",
"gym/envs/box2d/bipedal_walker.py",
"gym/envs/box2d/car_racing.py",
"gym/spaces/graph.py",
"gym/envs/mujoco/**",
"gym/utils/play.py",
"gym/vector/async_vector_env.py",
"gym/vector/utils/__init__.py",
"gym/wrappers/atari_preprocessing.py",
"gym/wrappers/gray_scale_observation.py",
"gym/wrappers/human_rendering.py",
"gym/wrappers/normalize.py",
"gym/wrappers/pixel_observation.py",
"gym/wrappers/record_video.py",
"gym/wrappers/monitoring/video_recorder.py",
"gym/wrappers/resize_observation.py",
"tests/envs/test_env_implementation.py",
"tests/utils/test_play.py",
"tests/vector/test_async_vector_env.py",
"tests/vector/test_shared_memory.py",
"tests/vector/test_spaces.py",
"tests/vector/test_sync_vector_env.py",
"tests/vector/test_vector_env.py",
"tests/wrappers/test_gray_scale_observation.py",
"tests/wrappers/test_order_enforcing.py",
"tests/wrappers/test_record_episode_statistics.py",
"tests/wrappers/test_resize_observation.py",
"tests/wrappers/test_time_aware_observation.py",
"tests/wrappers/test_video_recorder.py",
"gym/envs/mujoco/mujoco_env.py",
]
strict = [
@@ -51,6 +18,7 @@ strict = [
typeCheckingMode = "basic"
pythonVersion = "3.6"
pythonPlatform = "All"
typeshedPath = "typeshed"
enableTypeIgnoreComments = true

View File

@@ -1,5 +1,5 @@
"""Finds all the specs that we can test with"""
from typing import Optional
from typing import List, Optional
import numpy as np
@@ -18,22 +18,27 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
# Tries to make all environment to test with
all_testing_initialised_envs = list(
filter(None, [try_make_env(env_spec) for env_spec in gym.envs.registry.values()])
)
all_testing_initialised_envs: List[Optional[gym.Env]] = [
try_make_env(env_spec) for env_spec in gym.envs.registry.values()
]
all_testing_initialised_envs: List[gym.Env] = [
env for env in all_testing_initialised_envs if env is not None
]
# All testing, mujoco and gym environment specs
all_testing_env_specs = [env.spec for env in all_testing_initialised_envs]
mujoco_testing_env_specs = [
all_testing_env_specs: List[EnvSpec] = [
env.spec for env in all_testing_initialised_envs
]
mujoco_testing_env_specs: List[EnvSpec] = [
env_spec
for env_spec in all_testing_env_specs
if "gym.envs.mujoco" in env_spec.entry_point
]
gym_testing_env_specs = [
gym_testing_env_specs: List[EnvSpec] = [
env_spec
for env_spec in all_testing_env_specs
if any(
f"gym.{ep}" in env_spec.entry_point
f"gym.envs.{ep}" in env_spec.entry_point
for ep in ["box2d", "classic_control", "toy_text"]
)
]

View File

@@ -194,7 +194,7 @@ def test_play_loop_real_env():
# first action is 0 because at the first iteration
# we can not inject a callback event into play()
env.step(0)
obs, _, _, _ = env.step(0)
for e in keydown_events:
action = keys_to_action[chr(e.key) if str_keys else (e.key,)]
obs, _, _, _ = env.step(action)

View File

@@ -17,21 +17,19 @@ from tests.vector.utils import (
@pytest.mark.parametrize("shared_memory", [True, False])
def test_create_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
finally:
env.close()
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
assert env.num_envs == 8
env.close()
@pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations = env.reset()
finally:
env.close()
assert isinstance(env.observation_space, Box)
@@ -71,7 +69,7 @@ def test_reset_async_vector_env(shared_memory):
@pytest.mark.parametrize("use_single_action_space", [True, False])
def test_step_async_vector_env(shared_memory, use_single_action_space):
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations = env.reset()
@@ -83,7 +81,7 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
else:
actions = env.action_space.sample()
observations, rewards, dones, _ = env.step(actions)
finally:
env.close()
assert isinstance(env.observation_space, Box)
@@ -106,12 +104,12 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
@pytest.mark.parametrize("shared_memory", [True, False])
def test_call_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(4)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
_ = env.reset()
images = env.call("render")
gravity = env.call("gravity")
finally:
env.close()
assert isinstance(images, tuple)
@@ -130,59 +128,60 @@ def test_call_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
def test_set_attr_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
gravity = env.get_attr("gravity")
assert gravity == (9.81, 3.72, 8.87, 1.62)
finally:
env.close()
@pytest.mark.parametrize("shared_memory", [True, False])
def test_copy_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try:
# TODO, these tests do nothing, understand the purpose of the tests and fix them
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=True)
observations = env.reset()
observations[0] = 0
finally:
env.close()
@pytest.mark.parametrize("shared_memory", [True, False])
def test_no_copy_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try:
# TODO, these tests do nothing, understand the purpose of the tests and fix them
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=False)
observations = env.reset()
observations[0] = 0
finally:
env.close()
@pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_timeout_async_vector_env(shared_memory):
env_fns = [make_slow_env(0.3, i) for i in range(4)]
with pytest.raises(TimeoutError):
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
with pytest.raises(TimeoutError):
env.reset_async()
env.reset_wait(timeout=0.1)
finally:
env.close(terminate=True)
@pytest.mark.parametrize("shared_memory", [True, False])
def test_step_timeout_async_vector_env(shared_memory):
env_fns = [make_slow_env(0.0, i) for i in range(4)]
with pytest.raises(TimeoutError):
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
with pytest.raises(TimeoutError):
env.reset()
env.step_async([0.1, 0.1, 0.3, 0.1])
env.step_async(np.array([0.1, 0.1, 0.3, 0.1]))
observations, rewards, dones, _ = env.step_wait(timeout=0.1)
finally:
env.close(terminate=True)
@@ -190,19 +189,20 @@ def test_step_timeout_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_out_of_order_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
with pytest.raises(NoAsyncCallError):
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
env.reset_wait()
except NoAsyncCallError as exception:
assert exception.name == "reset"
raise
finally:
env.close(terminate=True)
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
with pytest.raises(AlreadyPendingCallError):
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
actions = env.action_space.sample()
env.reset()
env.step_async(actions)
@@ -210,7 +210,7 @@ def test_reset_out_of_order_async_vector_env(shared_memory):
except NoAsyncCallError as exception:
assert exception.name == "step"
raise
finally:
env.close(terminate=True)
@@ -218,28 +218,29 @@ def test_reset_out_of_order_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False])
def test_step_out_of_order_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
with pytest.raises(NoAsyncCallError):
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
actions = env.action_space.sample()
observations = env.reset()
observations, rewards, dones, infos = env.step_wait()
env.action_space.sample()
env.reset()
env.step_wait()
except AlreadyPendingCallError as exception:
assert exception.name == "step"
raise
finally:
env.close(terminate=True)
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
with pytest.raises(AlreadyPendingCallError):
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
actions = env.action_space.sample()
env.reset_async()
env.step_async(actions)
except AlreadyPendingCallError as exception:
assert exception.name == "reset"
raise
finally:
env.close(terminate=True)
@@ -265,7 +266,7 @@ def test_check_spaces_async_vector_env(shared_memory):
def test_custom_space_async_vector_env():
env_fns = [make_custom_space_env(i) for i in range(4)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=False)
reset_observations = env.reset()
@@ -274,7 +275,7 @@ def test_custom_space_async_vector_env():
actions = ("action-2", "action-3", "action-5", "action-7")
step_observations, rewards, dones, _ = env.step(actions)
finally:
env.close()
assert isinstance(env.single_observation_space, CustomSpace)

View File

@@ -116,6 +116,7 @@ def test_batch_space_custom_space(space, expected_batch_space_4):
def test_iterate(space, batch_space):
items = batch_space.sample()
iterator = iterate(batch_space, items)
i = 0
for i, item in enumerate(iterator):
assert item in space
assert i == 3
@@ -129,6 +130,7 @@ def test_iterate(space, batch_space):
def test_iterate_custom_space(space, batch_space):
items = batch_space.sample()
iterator = iterate(batch_space, items)
i = 0
for i, item in enumerate(iterator):
assert item in space
assert i == 3

View File

@@ -15,9 +15,7 @@ from tests.vector.utils import (
def test_create_sync_vector_env():
env_fns = [make_env("FrozenLake-v1", i) for i in range(8)]
try:
env = SyncVectorEnv(env_fns)
finally:
env.close()
assert env.num_envs == 8
@@ -25,10 +23,8 @@ def test_create_sync_vector_env():
def test_reset_sync_vector_env():
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try:
env = SyncVectorEnv(env_fns)
observations = env.reset()
finally:
env.close()
assert isinstance(env.observation_space, Box)
@@ -39,10 +35,8 @@ def test_reset_sync_vector_env():
del observations
try:
env = SyncVectorEnv(env_fns)
observations = env.reset(return_info=False)
finally:
env.close()
assert isinstance(env.observation_space, Box)
@@ -54,10 +48,9 @@ def test_reset_sync_vector_env():
del observations
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try:
env = SyncVectorEnv(env_fns)
observations, infos = env.reset(return_info=True)
finally:
env.close()
assert isinstance(env.observation_space, Box)
@@ -72,7 +65,7 @@ def test_reset_sync_vector_env():
@pytest.mark.parametrize("use_single_action_space", [True, False])
def test_step_sync_vector_env(use_single_action_space):
env_fns = [make_env("FrozenLake-v1", i) for i in range(8)]
try:
env = SyncVectorEnv(env_fns)
observations = env.reset()
@@ -84,7 +77,7 @@ def test_step_sync_vector_env(use_single_action_space):
else:
actions = env.action_space.sample()
observations, rewards, dones, _ = env.step(actions)
finally:
env.close()
assert isinstance(env.observation_space, MultiDiscrete)
@@ -106,12 +99,12 @@ def test_step_sync_vector_env(use_single_action_space):
def test_call_sync_vector_env():
env_fns = [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(4)]
try:
env = SyncVectorEnv(env_fns)
_ = env.reset()
images = env.call("render")
gravity = env.call("gravity")
finally:
env.close()
assert isinstance(images, tuple)
@@ -129,12 +122,12 @@ def test_call_sync_vector_env():
def test_set_attr_sync_vector_env():
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
try:
env = SyncVectorEnv(env_fns)
env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
gravity = env.get_attr("gravity")
assert gravity == (9.81, 3.72, 8.87, 1.62)
finally:
env.close()
@@ -150,7 +143,7 @@ def test_check_spaces_sync_vector_env():
def test_custom_space_sync_vector_env():
env_fns = [make_custom_space_env(i) for i in range(4)]
try:
env = SyncVectorEnv(env_fns)
reset_observations = env.reset()
@@ -159,7 +152,7 @@ def test_custom_space_sync_vector_env():
actions = ("action-2", "action-3", "action-5", "action-7")
step_observations, rewards, dones, _ = env.step(actions)
finally:
env.close()
assert isinstance(env.single_observation_space, CustomSpace)

View File

@@ -12,7 +12,7 @@ from tests.vector.utils import CustomSpace, make_env
def test_vector_env_equal(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
num_steps = 100
try:
async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
sync_env = SyncVectorEnv(env_fns)
@@ -45,7 +45,6 @@ def test_vector_env_equal(shared_memory):
assert np.all(async_rewards == sync_rewards)
assert np.all(async_dones == sync_dones)
finally:
async_env.close()
sync_env.close()

View File

@@ -2,6 +2,7 @@ import numpy as np
import pytest
import gym
from gym import spaces
from gym.wrappers import AtariPreprocessing, GrayScaleObservation
pytest.importorskip("gym.envs.atari")
@@ -20,8 +21,12 @@ def test_gray_scale_observation(env_id, keep_dim):
gym.make(env_id, disable_env_checker=True), screen_size=84, grayscale_obs=False
)
wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim)
assert isinstance(rgb_env.observation_space, spaces.Box)
assert rgb_env.observation_space.shape[-1] == 3
assert isinstance(wrapped_env.observation_space, spaces.Box)
seed = 0
gray_obs = gray_env.reset(seed=seed)

View File

@@ -26,16 +26,16 @@ def test_order_enforcing():
# Assert that the order enforcing works for step and render before reset
order_enforced_env = OrderEnforcing(env)
assert order_enforced_env._has_reset is False
assert order_enforced_env.has_reset is False
with pytest.raises(ResetNeeded):
order_enforced_env.step(0)
with pytest.raises(ResetNeeded):
order_enforced_env.render(mode="rgb_array")
assert order_enforced_env._has_reset is False
assert order_enforced_env.has_reset is False
# Assert that the Assertion errors are not raised after reset
order_enforced_env.reset()
assert order_enforced_env._has_reset is True
assert order_enforced_env.has_reset is True
order_enforced_env.step(0)
order_enforced_env.render(mode="rgb_array")

View File

@@ -14,6 +14,7 @@ def test_record_episode_statistics(env_id, deque_size):
for n in range(5):
env.reset()
assert env.episode_returns is not None and env.episode_lengths is not None
assert env.episode_returns[0] == 0.0
assert env.episode_lengths[0] == 0
for t in range(env.spec.max_episode_steps):

View File

@@ -1,6 +1,7 @@
import pytest
import gym
from gym import spaces
from gym.wrappers import ResizeObservation
pytest.importorskip("gym.envs.atari")
@@ -14,6 +15,7 @@ def test_resize_observation(env_id, shape):
env = gym.make(env_id, disable_env_checker=True)
env = ResizeObservation(env, shape)
assert isinstance(env.observation_space, spaces.Box)
assert env.observation_space.shape[-1] == 3
obs = env.reset()
if isinstance(shape, int):

View File

@@ -1,6 +1,7 @@
import pytest
import gym
from gym import spaces
from gym.wrappers import TimeAwareObservation
@@ -9,6 +10,8 @@ def test_time_aware_observation(env_id):
env = gym.make(env_id, disable_env_checker=True)
wrapped_env = TimeAwareObservation(env)
assert isinstance(env.observation_space, spaces.Box)
assert isinstance(wrapped_env.observation_space, spaces.Box)
assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1
obs = env.reset()

View File

@@ -33,9 +33,10 @@ def test_record_simple():
rec = VideoRecorder(env)
env.reset()
rec.capture_frame()
assert rec.encoder is not None
proc = rec.encoder.proc
assert proc.poll() is None # subprocess is running
assert proc is not None and proc.poll() is None # subprocess is running
rec.close()
@@ -55,9 +56,10 @@ def test_autoclose():
rec.capture_frame()
rec_path = rec.path
assert rec.encoder is not None
proc = rec.encoder.proc
assert proc.poll() is None # subprocess is running
assert proc is not None and proc.poll() is None # subprocess is running
# The function ends without an explicit `rec.close()` call
# The Python interpreter will implicitly do `del rec` on garbage cleaning
@@ -68,7 +70,7 @@ def test_autoclose():
gc.collect() # do explicit garbage collection for test
time.sleep(5) # wait for subprocess exiting
assert proc.poll() is not None # subprocess is terminated
assert proc is not None and proc.poll() is not None # subprocess is terminated
assert os.path.exists(rec_path)
f = open(rec_path)
assert os.fstat(f.fileno()).st_size > 100