Full type hinting (#2942)

* Allows a new RNG to be generated with seed=-1 and updates env_checker to fix a bug when the environment doesn't use np_random in reset

* Revert "fixed `gym.vector.make` where the checker was being applied in the opposite case than was intended to (#2871)"

This reverts commit 519dfd9117.

* Remove bad pushed commits

* Fixed spelling in core.py

* Pins pytest to the last py 3.6 version

* Allow Box automatic scalar shape

* Add test box and change default from () to (1,)

* update Box shape inference with more strict checking

* Update the box shape and add check on the custom Box shape

* Removed incorrect shape type and assert shape code

* Update the Box and associated tests

* Remove all folders and files from pyright exclude

* Revert issues

* Push RedTachyon code review

* Add Python Platform

* Remove play from pyright check

* Fixed CI issues

* remove mujoco env type hinting

* Fixed pixel observation test

* Added some new type hints

* Fixed CI errors

* Fixed CI errors

* Remove play.py from pyright exclude

* Fixed pyright issues
This commit is contained in:
Mark Towers
2022-07-04 18:19:25 +01:00
committed by GitHub
parent 9e66399b4e
commit 2ede09074f
61 changed files with 352 additions and 286 deletions

View File

@@ -1,6 +1,7 @@
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper.""" """Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
import sys import sys
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Dict, Dict,
Generic, Generic,
@@ -17,6 +18,9 @@ from gym.logger import deprecation, warn
from gym.utils import seeding from gym.utils import seeding
from gym.utils.seeding import RandomNumberGenerator from gym.utils.seeding import RandomNumberGenerator
if TYPE_CHECKING:
from gym.envs.registration import EnvSpec
if sys.version_info[0:2] == (3, 6): if sys.version_info[0:2] == (3, 6):
warn( warn(
"Gym minimally supports python 3.6 as the python foundation not longer supports the version, please update your version to 3.7+" "Gym minimally supports python 3.6 as the python foundation not longer supports the version, please update your version to 3.7+"
@@ -106,7 +110,7 @@ class Env(Generic[ObsType, ActType], metaclass=decorator):
metadata = {"render_modes": []} metadata = {"render_modes": []}
render_mode = None # define render_mode if your environment supports rendering render_mode = None # define render_mode if your environment supports rendering
reward_range = (-float("inf"), float("inf")) reward_range = (-float("inf"), float("inf"))
spec = None spec: "EnvSpec" = None
# Set these in ALL subclasses # Set these in ALL subclasses
action_space: spaces.Space[ActType] action_space: spaces.Space[ActType]

View File

@@ -1,7 +1,7 @@
__credits__ = ["Andrea PIERRÉ"] __credits__ = ["Andrea PIERRÉ"]
import math import math
from typing import Optional from typing import TYPE_CHECKING, List, Optional
import numpy as np import numpy as np
@@ -25,6 +25,9 @@ except ImportError:
raise DependencyNotInstalled("box2D is not installed, run `pip install gym[box2d]`") raise DependencyNotInstalled("box2D is not installed, run `pip install gym[box2d]`")
if TYPE_CHECKING:
import pygame
FPS = 50 FPS = 50
SCALE = 30.0 # affects how fast-paced the game is, forces should be adjusted as well SCALE = 30.0 # affects how fast-paced the game is, forces should be adjusted as well
@@ -170,8 +173,8 @@ class BipedalWalker(gym.Env, EzPickle):
self.isopen = True self.isopen = True
self.world = Box2D.b2World() self.world = Box2D.b2World()
self.terrain = None self.terrain: List[Box2D.b2Body] = []
self.hull = None self.hull: Optional[Box2D.b2Body] = None
self.prev_shaping = None self.prev_shaping = None
@@ -256,7 +259,7 @@ class BipedalWalker(gym.Env, EzPickle):
self.render_mode = render_mode self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render) self.renderer = Renderer(self.render_mode, self._render)
self.screen = None self.screen: Optional[pygame.Surface] = None
self.clock = None self.clock = None
def _destroy(self): def _destroy(self):
@@ -283,6 +286,9 @@ class BipedalWalker(gym.Env, EzPickle):
self.terrain = [] self.terrain = []
self.terrain_x = [] self.terrain_x = []
self.terrain_y = [] self.terrain_y = []
stair_steps, stair_width, stair_height = 0, 0, 0
original_y = 0
for i in range(TERRAIN_LENGTH): for i in range(TERRAIN_LENGTH):
x = i * TERRAIN_STEP x = i * TERRAIN_STEP
self.terrain_x.append(x) self.terrain_x.append(x)
@@ -448,8 +454,8 @@ class BipedalWalker(gym.Env, EzPickle):
(self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True (self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True
) )
self.legs = [] self.legs: List[Box2D.b2Body] = []
self.joints = [] self.joints: List[Box2D.b2RevoluteJoint] = []
for i in [-1, +1]: for i in [-1, +1]:
leg = self.world.CreateDynamicBody( leg = self.world.CreateDynamicBody(
position=(init_x, init_y - LEG_H / 2 - LEG_DOWN), position=(init_x, init_y - LEG_H / 2 - LEG_DOWN),
@@ -514,6 +520,8 @@ class BipedalWalker(gym.Env, EzPickle):
return self.step(np.array([0, 0, 0, 0]))[0], {} return self.step(np.array([0, 0, 0, 0]))[0], {}
def step(self, action: np.ndarray): def step(self, action: np.ndarray):
assert self.hull is not None
# self.hull.ApplyForceToCenter((0, 20), True) -- Uncomment this to receive a bit of stability help # self.hull.ApplyForceToCenter((0, 20), True) -- Uncomment this to receive a bit of stability help
control_speed = False # Should be easier as well control_speed = False # Should be easier as well
if control_speed: if control_speed:
@@ -737,6 +745,7 @@ class BipedalWalker(gym.Env, EzPickle):
self.surf = pygame.transform.flip(self.surf, False, True) self.surf = pygame.transform.flip(self.surf, False, True)
if mode == "human": if mode == "human":
assert self.screen is not None
self.screen.blit(self.surf, (-self.scroll * SCALE, 0)) self.screen.blit(self.surf, (-self.scroll * SCALE, 0))
pygame.event.pump() pygame.event.pump()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])

View File

@@ -9,6 +9,7 @@ Created by Oleg Klimov
import math import math
import Box2D
import numpy as np import numpy as np
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
@@ -48,8 +49,8 @@ MUD_COLOR = (102, 102, 0)
class Car: class Car:
def __init__(self, world, init_angle, init_x, init_y): def __init__(self, world, init_angle, init_x, init_y):
self.world = world self.world: Box2D.b2World = world
self.hull = self.world.CreateDynamicBody( self.hull: Box2D.b2Body = self.world.CreateDynamicBody(
position=(init_x, init_y), position=(init_x, init_y),
angle=init_angle, angle=init_angle,
fixtures=[ fixtures=[

View File

@@ -197,14 +197,14 @@ class CarRacing(gym.Env, EzPickle):
self.contactListener_keepref = FrictionDetector(self, lap_complete_percent) self.contactListener_keepref = FrictionDetector(self, lap_complete_percent)
self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref) self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref)
self.screen = None self.screen: Optional[pygame.Surface] = None
self.surf = None self.surf = None
self.clock = None self.clock = None
self.isopen = True self.isopen = True
self.invisible_state_window = None self.invisible_state_window = None
self.invisible_video_window = None self.invisible_video_window = None
self.road = None self.road = None
self.car = None self.car: Optional[Car] = None
self.reward = 0.0 self.reward = 0.0
self.prev_reward = 0.0 self.prev_reward = 0.0
self.verbose = verbose self.verbose = verbose
@@ -237,6 +237,7 @@ class CarRacing(gym.Env, EzPickle):
for t in self.road: for t in self.road:
self.world.DestroyBody(t) self.world.DestroyBody(t)
self.road = [] self.road = []
assert self.car is not None
self.car.destroy() self.car.destroy()
def _init_colors(self): def _init_colors(self):
@@ -502,6 +503,7 @@ class CarRacing(gym.Env, EzPickle):
return self.step(None)[0], {} return self.step(None)[0], {}
def step(self, action: Union[np.ndarray, int]): def step(self, action: Union[np.ndarray, int]):
assert self.car is not None
if action is not None: if action is not None:
if self.continuous: if self.continuous:
self.car.steer(-action[0]) self.car.steer(-action[0])
@@ -576,6 +578,7 @@ class CarRacing(gym.Env, EzPickle):
self.surf = pygame.Surface((WINDOW_W, WINDOW_H)) self.surf = pygame.Surface((WINDOW_W, WINDOW_H))
assert self.car is not None
# computing transformations # computing transformations
angle = -self.car.hull.angle angle = -self.car.hull.angle
# Animating first second zoom. # Animating first second zoom.
@@ -608,6 +611,7 @@ class CarRacing(gym.Env, EzPickle):
if mode == "human": if mode == "human":
pygame.event.pump() pygame.event.pump()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
assert self.screen is not None
self.screen.fill(0) self.screen.fill(0)
self.screen.blit(self.surf, (0, 0)) self.screen.blit(self.surf, (0, 0))
pygame.display.flip() pygame.display.flip()
@@ -682,6 +686,7 @@ class CarRacing(gym.Env, EzPickle):
((place + 0) * s, H - 2 * h), ((place + 0) * s, H - 2 * h),
] ]
assert self.car is not None
true_speed = np.sqrt( true_speed = np.sqrt(
np.square(self.car.hull.linearVelocity[0]) np.square(self.car.hull.linearVelocity[0])
+ np.square(self.car.hull.linearVelocity[1]) + np.square(self.car.hull.linearVelocity[1])

View File

@@ -2,7 +2,7 @@ __credits__ = ["Andrea PIERRÉ"]
import math import math
import warnings import warnings
from typing import Optional from typing import TYPE_CHECKING, Optional
import numpy as np import numpy as np
@@ -25,6 +25,11 @@ try:
except ImportError: except ImportError:
raise DependencyNotInstalled("box2d is not installed, run `pip install gym[box2d]`") raise DependencyNotInstalled("box2d is not installed, run `pip install gym[box2d]`")
if TYPE_CHECKING:
import pygame
FPS = 50 FPS = 50
SCALE = 30.0 # affects how fast-paced the game is, forces should be adjusted as well SCALE = 30.0 # affects how fast-paced the game is, forces should be adjusted as well
@@ -215,7 +220,7 @@ class LunarLander(gym.Env, EzPickle):
self.wind_idx = np.random.randint(-9999, 9999) self.wind_idx = np.random.randint(-9999, 9999)
self.torque_idx = np.random.randint(-9999, 9999) self.torque_idx = np.random.randint(-9999, 9999)
self.screen = None self.screen: pygame.Surface = None
self.clock = None self.clock = None
self.isopen = True self.isopen = True
self.world = Box2D.b2World(gravity=(0, gravity)) self.world = Box2D.b2World(gravity=(0, gravity))
@@ -427,6 +432,8 @@ class LunarLander(gym.Env, EzPickle):
self.world.DestroyBody(self.particles.pop(0)) self.world.DestroyBody(self.particles.pop(0))
def step(self, action): def step(self, action):
assert self.lander is not None
# Update wind # Update wind
assert self.lander is not None, "You forgot to call reset()" assert self.lander is not None, "You forgot to call reset()"
if self.enable_wind and not ( if self.enable_wind and not (
@@ -604,10 +611,6 @@ class LunarLander(gym.Env, EzPickle):
if self.clock is None: if self.clock is None:
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
assert (
self.screen is not None
), "Something went wrong with pygame, there is no screen to render"
self.surf = pygame.Surface((VIEWPORT_W, VIEWPORT_H)) self.surf = pygame.Surface((VIEWPORT_W, VIEWPORT_H))
pygame.transform.scale(self.surf, (SCALE, SCALE)) pygame.transform.scale(self.surf, (SCALE, SCALE))

View File

@@ -79,4 +79,5 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs() return self._get_obs()
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.distance = self.model.stat.extent * 0.5 self.viewer.cam.distance = self.model.stat.extent * 0.5

View File

@@ -173,6 +173,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation return observation
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items(): for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value getattr(self.viewer.cam, key)[:] = value

View File

@@ -336,6 +336,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation return observation
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items(): for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value getattr(self.viewer.cam, key)[:] = value

View File

@@ -60,4 +60,5 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs() return self._get_obs()
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.distance = self.model.stat.extent * 0.5 self.viewer.cam.distance = self.model.stat.extent * 0.5

View File

@@ -118,6 +118,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation return observation
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items(): for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value getattr(self.viewer.cam, key)[:] = value

View File

@@ -231,6 +231,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation return observation
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items(): for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value getattr(self.viewer.cam, key)[:] = value

View File

@@ -66,6 +66,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs() return self._get_obs()
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 2 self.viewer.cam.trackbodyid = 2
self.viewer.cam.distance = self.model.stat.extent * 0.75 self.viewer.cam.distance = self.model.stat.extent * 0.75
self.viewer.cam.lookat[2] = 1.15 self.viewer.cam.lookat[2] = 1.15

View File

@@ -163,6 +163,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation return observation
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items(): for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value getattr(self.viewer.cam, key)[:] = value

View File

@@ -279,6 +279,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation return observation
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items(): for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value getattr(self.viewer.cam, key)[:] = value

View File

@@ -92,6 +92,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs() return self._get_obs()
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 1 self.viewer.cam.trackbodyid = 1
self.viewer.cam.distance = self.model.stat.extent * 1.0 self.viewer.cam.distance = self.model.stat.extent * 1.0
self.viewer.cam.lookat[2] = 2.0 self.viewer.cam.lookat[2] = 2.0

View File

@@ -185,6 +185,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation return observation
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items(): for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value getattr(self.viewer.cam, key)[:] = value

View File

@@ -358,6 +358,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation return observation
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items(): for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value getattr(self.viewer.cam, key)[:] = value

View File

@@ -83,6 +83,7 @@ class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs() return self._get_obs()
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 1 self.viewer.cam.trackbodyid = 1
self.viewer.cam.distance = self.model.stat.extent * 1.0 self.viewer.cam.distance = self.model.stat.extent * 1.0
self.viewer.cam.lookat[2] = 0.8925 self.viewer.cam.lookat[2] = 0.8925

View File

@@ -254,6 +254,7 @@ class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs() return self._get_obs()
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 1 self.viewer.cam.trackbodyid = 1
self.viewer.cam.distance = self.model.stat.extent * 1.0 self.viewer.cam.distance = self.model.stat.extent * 1.0
self.viewer.cam.lookat[2] = 0.8925 self.viewer.cam.lookat[2] = 0.8925

View File

@@ -64,6 +64,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs() return self._get_obs()
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
v = self.viewer v = self.viewer
v.cam.trackbodyid = 0 v.cam.trackbodyid = 0
v.cam.distance = self.model.stat.extent * 0.5 v.cam.distance = self.model.stat.extent * 0.5

View File

@@ -169,6 +169,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs() return self._get_obs()
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
v = self.viewer v = self.viewer
v.cam.trackbodyid = 0 v.cam.trackbodyid = 0
v.cam.distance = self.model.stat.extent * 0.5 v.cam.distance = self.model.stat.extent * 0.5

View File

@@ -54,6 +54,6 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return np.concatenate([self.sim.data.qpos, self.sim.data.qvel]).ravel() return np.concatenate([self.sim.data.qpos, self.sim.data.qvel]).ravel()
def viewer_setup(self): def viewer_setup(self):
v = self.viewer assert self.viewer is not None
v.cam.trackbodyid = 0 self.viewer.cam.trackbodyid = 0
v.cam.distance = self.model.stat.extent self.viewer.cam.distance = self.model.stat.extent

View File

@@ -130,6 +130,7 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return np.concatenate([self.data.qpos, self.data.qvel]).ravel() return np.concatenate([self.data.qpos, self.data.qvel]).ravel()
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
v = self.viewer v = self.viewer
v.cam.trackbodyid = 0 v.cam.trackbodyid = 0
v.cam.distance = self.model.stat.extent v.cam.distance = self.model.stat.extent

View File

@@ -47,6 +47,7 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = -1 self.viewer.cam.trackbodyid = -1
self.viewer.cam.distance = 4.0 self.viewer.cam.distance = 4.0

View File

@@ -164,6 +164,7 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = -1 self.viewer.cam.trackbodyid = -1
self.viewer.cam.distance = 4.0 self.viewer.cam.distance = 4.0

View File

@@ -43,6 +43,7 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 0 self.viewer.cam.trackbodyid = 0
def reset_model(self): def reset_model(self):

View File

@@ -150,6 +150,7 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 0 self.viewer.cam.trackbodyid = 0
def reset_model(self): def reset_model(self):

View File

@@ -119,6 +119,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation return observation
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items(): for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value getattr(self.viewer.cam, key)[:] = value

View File

@@ -225,6 +225,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation return observation
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items(): for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value getattr(self.viewer.cam, key)[:] = value

View File

@@ -60,6 +60,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return self._get_obs() return self._get_obs()
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 2 self.viewer.cam.trackbodyid = 2
self.viewer.cam.distance = self.model.stat.extent * 0.5 self.viewer.cam.distance = self.model.stat.extent * 0.5
self.viewer.cam.lookat[2] = 1.15 self.viewer.cam.lookat[2] = 1.15

View File

@@ -153,6 +153,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation return observation
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items(): for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value getattr(self.viewer.cam, key)[:] = value

View File

@@ -277,6 +277,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
return observation return observation
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
for key, value in DEFAULT_CAMERA_CONFIG.items(): for key, value in DEFAULT_CAMERA_CONFIG.items():
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
getattr(self.viewer.cam, key)[:] = value getattr(self.viewer.cam, key)[:] = value

View File

@@ -134,10 +134,12 @@ class Graph(Space):
assert ( assert (
num_edges >= 0 num_edges >= 0
), f"The number of edges is expected to be greater than 0, actual mask: {num_edges}" ), f"The number of edges is expected to be greater than 0, actual mask: {num_edges}"
assert num_edges is not None
sampled_node_space = self._generate_sample_space(self.node_space, num_nodes) sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges) sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges)
assert sampled_node_space is not None
sampled_nodes = sampled_node_space.sample(node_space_mask) sampled_nodes = sampled_node_space.sample(node_space_mask)
sampled_edges = ( sampled_edges = (
sampled_edge_space.sample(edge_space_mask) sampled_edge_space.sample(edge_space_mask)

View File

@@ -3,16 +3,23 @@ from collections import deque
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import pygame
from pygame import Surface
from pygame.event import Event
from pygame.locals import VIDEORESIZE
import gym.error
from gym import Env, logger from gym import Env, logger
from gym.core import ActType, ObsType from gym.core import ActType, ObsType
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.logger import deprecation from gym.logger import deprecation
try:
import pygame
from pygame import Surface
from pygame.event import Event
from pygame.locals import VIDEORESIZE
except ImportError:
raise gym.error.DependencyNotInstalled(
"Pygame is not installed, run `pip install gym[classic_control]`"
)
try: try:
import matplotlib import matplotlib
@@ -20,7 +27,7 @@ try:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
except ImportError: except ImportError:
logger.warn("Matplotlib is not installed, run `pip install gym[other]`") logger.warn("Matplotlib is not installed, run `pip install gym[other]`")
matplotlib, plt = None, None plt = None
class MissingKeysToAction(Exception): class MissingKeysToAction(Exception):
@@ -33,7 +40,7 @@ class PlayableGame:
def __init__( def __init__(
self, self,
env: Env, env: Env,
keys_to_action: Optional[Dict[Tuple[int], int]] = None, keys_to_action: Optional[Dict[Tuple[int, ...], int]] = None,
zoom: Optional[float] = None, zoom: Optional[float] = None,
): ):
"""Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment. """Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.
@@ -63,12 +70,14 @@ class PlayableGame:
f"{self.env.spec.id} does not have explicit key to action mapping, " f"{self.env.spec.id} does not have explicit key to action mapping, "
"please specify one manually" "please specify one manually"
) )
assert isinstance(keys_to_action, dict)
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), [])) relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
return relevant_keys return relevant_keys
def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]: def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]:
# TODO: this needs to be updated when the render API change goes through # TODO: this needs to be updated when the render API change goes through
rendered = self.env.render(mode="rgb_array") rendered = self.env.render(mode="rgb_array")
assert rendered is not None and isinstance(rendered, np.ndarray)
video_size = [rendered.shape[1], rendered.shape[0]] video_size = [rendered.shape[1], rendered.shape[0]]
if zoom is not None: if zoom is not None:
@@ -211,9 +220,9 @@ def play(
f"{env.spec.id} does not have explicit key to action mapping, " f"{env.spec.id} does not have explicit key to action mapping, "
"please specify one manually" "please specify one manually"
) )
assert keys_to_action is not None
key_code_to_action = {} key_code_to_action = {}
for key_combination, action in keys_to_action.items(): for key_combination, action in keys_to_action.items():
key_code = tuple( key_code = tuple(
sorted(ord(key) if isinstance(key, str) else key for key in key_combination) sorted(ord(key) if isinstance(key, str) else key for key in key_combination)
@@ -225,7 +234,7 @@ def play(
if fps is None: if fps is None:
fps = env.metadata.get("render_fps", 30) fps = env.metadata.get("render_fps", 30)
done = True done, obs = True, None
clock = pygame.time.Clock() clock = pygame.time.Clock()
while game.running: while game.running:
@@ -316,7 +325,7 @@ class PlayPlot:
for axis, name in zip(self.ax, plot_names): for axis, name in zip(self.ax, plot_names):
axis.set_title(name) axis.set_title(name)
self.t = 0 self.t = 0
self.cur_plot = [None for _ in range(num_plots)] self.cur_plot: List[Optional[plt.Axes]] = [None for _ in range(num_plots)]
self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)] self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]
def callback( def callback(
@@ -352,4 +361,9 @@ class PlayPlot:
range(xmin, xmax), list(self.data[i]), c="blue" range(xmin, xmax), list(self.data[i]), c="blue"
) )
self.ax[i].set_xlim(xmin, xmax) self.ax[i].set_xlim(xmin, xmax)
if plt is None:
raise DependencyNotInstalled(
"matplotlib is not installed, run `pip install gym[other]`"
)
plt.pause(0.000001) plt.pause(0.000001)

View File

@@ -569,7 +569,7 @@ class AsyncVectorEnv(VectorEnv):
num_errors = self.num_envs - sum(successes) num_errors = self.num_envs - sum(successes)
assert num_errors > 0 assert num_errors > 0
for _ in range(num_errors): for i in range(num_errors):
index, exctype, value = self.error_queue.get() index, exctype, value = self.error_queue.get()
logger.error( logger.error(
f"Received the following error from Worker-{index}: {exctype.__name__}: {value}" f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
@@ -578,8 +578,9 @@ class AsyncVectorEnv(VectorEnv):
self.parent_pipes[index].close() self.parent_pipes[index].close()
self.parent_pipes[index] = None self.parent_pipes[index] = None
logger.error("Raising the last exception back to the main process.") if i == num_errors - 1:
raise exctype(value) logger.error("Raising the last exception back to the main process.")
raise exctype(value)
def __del__(self): def __del__(self):
"""On deleting the object, checks that the vector environment is closed.""" """On deleting the object, checks that the vector environment is closed."""

View File

@@ -1,9 +1,10 @@
"""A synchronous vector environment.""" """A synchronous vector environment."""
from copy import deepcopy from copy import deepcopy
from typing import Any, Iterator, List, Optional, Sequence, Union from typing import Any, Callable, Iterator, List, Optional, Sequence, Union
import numpy as np import numpy as np
from gym import Env
from gym.spaces import Space from gym.spaces import Space
from gym.vector.utils import concatenate, create_empty_array, iterate from gym.vector.utils import concatenate, create_empty_array, iterate
from gym.vector.vector_env import VectorEnv from gym.vector.vector_env import VectorEnv
@@ -28,7 +29,7 @@ class SyncVectorEnv(VectorEnv):
def __init__( def __init__(
self, self,
env_fns: Iterator[callable], env_fns: Iterator[Callable[[], Env]],
observation_space: Space = None, observation_space: Space = None,
action_space: Space = None, action_space: Space = None,
copy: bool = True, copy: bool = True,

View File

@@ -6,7 +6,8 @@ from gym.vector.utils.shared_memory import (
read_from_shared_memory, read_from_shared_memory,
write_to_shared_memory, write_to_shared_memory,
) )
from gym.vector.utils.spaces import BaseGymSpaces, _BaseGymSpaces, batch_space, iterate from gym.vector.utils.spaces import _BaseGymSpaces # pyright: reportPrivateUsage=false
from gym.vector.utils.spaces import BaseGymSpaces, batch_space, iterate
__all__ = [ __all__ = [
"CloudpickleWrapper", "CloudpickleWrapper",
@@ -17,7 +18,6 @@ __all__ = [
"read_from_shared_memory", "read_from_shared_memory",
"write_to_shared_memory", "write_to_shared_memory",
"BaseGymSpaces", "BaseGymSpaces",
"_BaseGymSpaces",
"batch_space", "batch_space",
"iterate", "iterate",
] ]

View File

@@ -85,6 +85,7 @@ class VectorEnv(gym.Env):
Raises: Raises:
NotImplementedError: VectorEnv does not implement function NotImplementedError: VectorEnv does not implement function
""" """
raise NotImplementedError("VectorEnv does not implement function")
def reset( def reset(
self, self,

View File

@@ -2,13 +2,14 @@
import numpy as np import numpy as np
import gym import gym
from gym.error import DependencyNotInstalled
from gym.spaces import Box from gym.spaces import Box
try: try:
import cv2 import cv2
except ImportError: except ImportError:
cv2 = None raise gym.error.DependencyNotInstalled(
"opencv-python package not installed, run `pip install gym[other]` to get dependencies for atari"
)
class AtariPreprocessing(gym.Wrapper): class AtariPreprocessing(gym.Wrapper):
@@ -60,10 +61,6 @@ class AtariPreprocessing(gym.Wrapper):
ValueError: Disable frame-skipping in the original env ValueError: Disable frame-skipping in the original env
""" """
super().__init__(env) super().__init__(env)
if cv2 is None:
raise DependencyNotInstalled(
"opencv-python package not installed, run `pip install gym[other]` to get dependencies for atari"
)
assert frame_skip > 0 assert frame_skip > 0
assert screen_size > 0 assert screen_size > 0
assert noop_max >= 0 assert noop_max >= 0
@@ -87,6 +84,7 @@ class AtariPreprocessing(gym.Wrapper):
self.scale_obs = scale_obs self.scale_obs = scale_obs
# buffer of most recent two observations for max pooling # buffer of most recent two observations for max pooling
assert isinstance(env.observation_space, Box)
if grayscale_obs: if grayscale_obs:
self.obs_buffer = [ self.obs_buffer = [
np.empty(env.observation_space.shape[:2], dtype=np.uint8), np.empty(env.observation_space.shape[:2], dtype=np.uint8),
@@ -114,7 +112,7 @@ class AtariPreprocessing(gym.Wrapper):
def step(self, action): def step(self, action):
"""Applies the preprocessing for an :meth:`env.step`.""" """Applies the preprocessing for an :meth:`env.step`."""
total_reward = 0.0 total_reward, done, info = 0.0, False, {}
for t in range(self.frame_skip): for t in range(self.frame_skip):
_, reward, done, info = self.env.step(action) _, reward, done, info = self.env.step(action)

View File

@@ -32,8 +32,9 @@ class GrayScaleObservation(gym.ObservationWrapper):
self.keep_dim = keep_dim self.keep_dim = keep_dim
assert ( assert (
len(env.observation_space.shape) == 3 isinstance(self.observation_space, Box)
and env.observation_space.shape[-1] == 3 and len(self.observation_space.shape) == 3
and self.observation_space.shape[-1] == 3
) )
obs_shape = self.observation_space.shape[:2] obs_shape = self.observation_space.shape[:2]

View File

@@ -88,13 +88,16 @@ class HumanRendering(gym.Wrapper):
"pygame is not installed, run `pip install gym[box2d]`" "pygame is not installed, run `pip install gym[box2d]`"
) )
if self.env.render_mode == "rgb_array": if self.env.render_mode == "rgb_array":
last_rgb_array = self.env.render(**kwargs)[-1] last_rgb_array = self.env.render(**kwargs)
assert isinstance(last_rgb_array, list)
last_rgb_array = last_rgb_array[-1]
elif self.env.render_mode == "single_rgb_array": elif self.env.render_mode == "single_rgb_array":
last_rgb_array = self.env.render(**kwargs) last_rgb_array = self.env.render(**kwargs)
else: else:
raise Exception( raise Exception(
"Wrapped environment must have mode 'rgb_array' or 'single_rgb_array'" f"Wrapped environment must have mode 'rgb_array' or 'single_rgb_array', actual render mode: {self.env.render_mode}"
) )
assert isinstance(last_rgb_array, np.ndarray)
if mode == "human": if mode == "human":
rgb_array = np.transpose(last_rgb_array, axes=(1, 0, 2)) rgb_array = np.transpose(last_rgb_array, axes=(1, 0, 2))

View File

@@ -148,7 +148,9 @@ class VideoRecorder: # TODO: remove with gym 1.0
) )
self.output_frames_per_sec = self.backward_compatible_output_frames_per_sec self.output_frames_per_sec = self.backward_compatible_output_frames_per_sec
self.encoder = None # lazily start the process self.encoder: Optional[
Union[TextEncoder, ImageEncoder]
] = None # lazily start the process
self.broken = False self.broken = False
# Dump metadata # Dump metadata
@@ -387,7 +389,7 @@ class ImageEncoder:
InvalidFrame: Expects frame to have shape (w,h,3) or (w,h,4) InvalidFrame: Expects frame to have shape (w,h,3) or (w,h,4)
DependencyNotInstalled: Found neither the ffmpeg nor avconv executables. DependencyNotInstalled: Found neither the ffmpeg nor avconv executables.
""" """
self.proc = None self.proc: Optional[subprocess.Popen] = None
self.output_path = output_path self.output_path = output_path
# Frame shape should be lines-first, so w and h are swapped # Frame shape should be lines-first, so w and h are swapped
h, w, pixfmt = frame_shape h, w, pixfmt = frame_shape
@@ -488,6 +490,7 @@ class ImageEncoder:
f"Your frame has data type {frame.dtype}, but we require uint8 (i.e. RGB values from 0-255)." f"Your frame has data type {frame.dtype}, but we require uint8 (i.e. RGB values from 0-255)."
) )
assert self.proc is not None and self.proc.stdin is not None
try: try:
self.proc.stdin.write(frame.tobytes()) self.proc.stdin.write(frame.tobytes())
except Exception: except Exception:
@@ -496,6 +499,7 @@ class ImageEncoder:
def close(self): def close(self):
"""Closes the Image encoder.""" """Closes the Image encoder."""
assert self.proc is not None and self.proc.stdin is not None
self.proc.stdin.close() self.proc.stdin.close()
ret = self.proc.wait() ret = self.proc.wait()
if ret != 0: if ret != 0:

View File

@@ -81,19 +81,20 @@ class NormalizeObservation(gym.core.Wrapper):
def reset(self, **kwargs): def reset(self, **kwargs):
"""Resets the environment and normalizes the observation.""" """Resets the environment and normalizes the observation."""
return_info = kwargs.get("return_info", False) if kwargs.get("return_info", False):
if return_info:
obs, info = self.env.reset(**kwargs) obs, info = self.env.reset(**kwargs)
if self.is_vector_env:
return self.normalize(obs), info
else:
return self.normalize(np.array([obs]))[0], info
else: else:
obs = self.env.reset(**kwargs) obs = self.env.reset(**kwargs)
if self.is_vector_env:
obs = self.normalize(obs) if self.is_vector_env:
else: return self.normalize(obs)
obs = self.normalize(np.array([obs]))[0] else:
if not return_info: return self.normalize(np.array([obs]))[0]
return obs
else:
return obs, info
def normalize(self, obs): def normalize(self, obs):
"""Normalises the observation using the running mean and variance of the observations.""" """Normalises the observation using the running mean and variance of the observations."""

View File

@@ -49,3 +49,8 @@ class OrderEnforcing(gym.Wrapper):
"set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper." "set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
) )
return self.env.render(*args, **kwargs) return self.env.render(*args, **kwargs)
@property
def has_reset(self):
"""Returns if the environment has been reset before."""
return self._has_reset

View File

@@ -120,8 +120,7 @@ class PixelObservationWrapper(gym.ObservationWrapper):
elif self._observation_is_dict: elif self._observation_is_dict:
self.observation_space = copy.deepcopy(wrapped_observation_space) self.observation_space = copy.deepcopy(wrapped_observation_space)
else: else:
self.observation_space = spaces.Dict() self.observation_space = spaces.Dict({STATE_KEY: wrapped_observation_space})
self.observation_space.spaces[STATE_KEY] = wrapped_observation_space
# Extend observation space with pixels. # Extend observation space with pixels.
@@ -129,7 +128,7 @@ class PixelObservationWrapper(gym.ObservationWrapper):
pixels_spaces = {} pixels_spaces = {}
for pixel_key in pixel_keys: for pixel_key in pixel_keys:
pixels = self.env.render(**render_kwargs[pixel_key]) pixels = self.env.render(**render_kwargs[pixel_key])
pixels = pixels[-1] if isinstance(pixels, List) else pixels pixels: np.ndarray = pixels[-1] if isinstance(pixels, List) else pixels
if np.issubdtype(pixels.dtype, np.integer): if np.issubdtype(pixels.dtype, np.integer):
low, high = (0, 255) low, high = (0, 255)

View File

@@ -1,6 +1,7 @@
"""Wrapper that tracks the cumulative rewards and episode lengths.""" """Wrapper that tracks the cumulative rewards and episode lengths."""
import time import time
from collections import deque from collections import deque
from typing import Optional
import numpy as np import numpy as np
@@ -86,8 +87,8 @@ class RecordEpisodeStatistics(gym.Wrapper):
self.num_envs = getattr(env, "num_envs", 1) self.num_envs = getattr(env, "num_envs", 1)
self.t0 = time.perf_counter() self.t0 = time.perf_counter()
self.episode_count = 0 self.episode_count = 0
self.episode_returns = None self.episode_returns: Optional[np.ndarray] = None
self.episode_lengths = None self.episode_lengths: Optional[np.ndarray] = None
self.return_queue = deque(maxlen=deque_size) self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size) self.length_queue = deque(maxlen=deque_size)
self.is_vector_env = getattr(env, "is_vector_env", False) self.is_vector_env = getattr(env, "is_vector_env", False)

View File

@@ -1,6 +1,6 @@
"""Wrapper for recording videos.""" """Wrapper for recording videos."""
import os import os
from typing import Callable from typing import Callable, Optional
import gym import gym
from gym import logger from gym import logger
@@ -77,7 +77,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
self.episode_trigger = episode_trigger self.episode_trigger = episode_trigger
self.step_trigger = step_trigger self.step_trigger = step_trigger
self.video_recorder = None self.video_recorder: Optional[video_recorder.VideoRecorder] = None
self.video_folder = os.path.abspath(video_folder) self.video_folder = os.path.abspath(video_folder)
# Create output folder if needed # Create output folder if needed
@@ -101,6 +101,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
"""Reset the environment using kwargs and then starts recording if video enabled.""" """Reset the environment using kwargs and then starts recording if video enabled."""
observations = super().reset(**kwargs) observations = super().reset(**kwargs)
if self.recording: if self.recording:
assert self.video_recorder is not None
self.video_recorder.capture_frame() self.video_recorder.capture_frame()
self.recorded_frames += 1 self.recorded_frames += 1
if self.video_length > 0: if self.video_length > 0:
@@ -148,6 +149,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
self.episode_id += 1 self.episode_id += 1
if self.recording: if self.recording:
assert self.video_recorder is not None
self.video_recorder.capture_frame() self.video_recorder.capture_frame()
self.recorded_frames += 1 self.recorded_frames += 1
if self.video_length > 0: if self.video_length > 0:
@@ -168,6 +170,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
def close_video_recorder(self): def close_video_recorder(self):
"""Closes the video recorder if currently recording.""" """Closes the video recorder if currently recording."""
if self.recording: if self.recording:
assert self.video_recorder is not None
self.video_recorder.close() self.video_recorder.close()
self.recording = False self.recording = False
self.recorded_frames = 1 self.recorded_frames = 1

View File

@@ -39,7 +39,10 @@ class ResizeObservation(gym.ObservationWrapper):
self.shape = tuple(shape) self.shape = tuple(shape)
obs_shape = self.shape + self.observation_space.shape[2:] assert isinstance(
env.observation_space, Box
), f"Expected the observation space to be Box, actual type: {type(env.observation_space)}"
obs_shape = self.shape + env.observation_space.shape[2:]
self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8) self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
def observation(self, observation): def observation(self, observation):

View File

@@ -9,40 +9,7 @@ exclude = [
"**/node_modules", "**/node_modules",
"**/__pycache__", "**/__pycache__",
"gym/envs/box2d/bipedal_walker.py", "gym/envs/mujoco/mujoco_env.py",
"gym/envs/box2d/car_racing.py",
"gym/spaces/graph.py",
"gym/envs/mujoco/**",
"gym/utils/play.py",
"gym/vector/async_vector_env.py",
"gym/vector/utils/__init__.py",
"gym/wrappers/atari_preprocessing.py",
"gym/wrappers/gray_scale_observation.py",
"gym/wrappers/human_rendering.py",
"gym/wrappers/normalize.py",
"gym/wrappers/pixel_observation.py",
"gym/wrappers/record_video.py",
"gym/wrappers/monitoring/video_recorder.py",
"gym/wrappers/resize_observation.py",
"tests/envs/test_env_implementation.py",
"tests/utils/test_play.py",
"tests/vector/test_async_vector_env.py",
"tests/vector/test_shared_memory.py",
"tests/vector/test_spaces.py",
"tests/vector/test_sync_vector_env.py",
"tests/vector/test_vector_env.py",
"tests/wrappers/test_gray_scale_observation.py",
"tests/wrappers/test_order_enforcing.py",
"tests/wrappers/test_record_episode_statistics.py",
"tests/wrappers/test_resize_observation.py",
"tests/wrappers/test_time_aware_observation.py",
"tests/wrappers/test_video_recorder.py",
] ]
strict = [ strict = [
@@ -51,6 +18,7 @@ strict = [
typeCheckingMode = "basic" typeCheckingMode = "basic"
pythonVersion = "3.6" pythonVersion = "3.6"
pythonPlatform = "All"
typeshedPath = "typeshed" typeshedPath = "typeshed"
enableTypeIgnoreComments = true enableTypeIgnoreComments = true

View File

@@ -1,5 +1,5 @@
"""Finds all the specs that we can test with""" """Finds all the specs that we can test with"""
from typing import Optional from typing import List, Optional
import numpy as np import numpy as np
@@ -18,22 +18,27 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
# Tries to make all environment to test with # Tries to make all environment to test with
all_testing_initialised_envs = list( all_testing_initialised_envs: List[Optional[gym.Env]] = [
filter(None, [try_make_env(env_spec) for env_spec in gym.envs.registry.values()]) try_make_env(env_spec) for env_spec in gym.envs.registry.values()
) ]
all_testing_initialised_envs: List[gym.Env] = [
env for env in all_testing_initialised_envs if env is not None
]
# All testing, mujoco and gym environment specs # All testing, mujoco and gym environment specs
all_testing_env_specs = [env.spec for env in all_testing_initialised_envs] all_testing_env_specs: List[EnvSpec] = [
mujoco_testing_env_specs = [ env.spec for env in all_testing_initialised_envs
]
mujoco_testing_env_specs: List[EnvSpec] = [
env_spec env_spec
for env_spec in all_testing_env_specs for env_spec in all_testing_env_specs
if "gym.envs.mujoco" in env_spec.entry_point if "gym.envs.mujoco" in env_spec.entry_point
] ]
gym_testing_env_specs = [ gym_testing_env_specs: List[EnvSpec] = [
env_spec env_spec
for env_spec in all_testing_env_specs for env_spec in all_testing_env_specs
if any( if any(
f"gym.{ep}" in env_spec.entry_point f"gym.envs.{ep}" in env_spec.entry_point
for ep in ["box2d", "classic_control", "toy_text"] for ep in ["box2d", "classic_control", "toy_text"]
) )
] ]

View File

@@ -194,7 +194,7 @@ def test_play_loop_real_env():
# first action is 0 because at the first iteration # first action is 0 because at the first iteration
# we can not inject a callback event into play() # we can not inject a callback event into play()
env.step(0) obs, _, _, _ = env.step(0)
for e in keydown_events: for e in keydown_events:
action = keys_to_action[chr(e.key) if str_keys else (e.key,)] action = keys_to_action[chr(e.key) if str_keys else (e.key,)]
obs, _, _, _ = env.step(action) obs, _, _, _ = env.step(action)

View File

@@ -17,22 +17,20 @@ from tests.vector.utils import (
@pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("shared_memory", [True, False])
def test_create_async_vector_env(shared_memory): def test_create_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(8)] env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
finally:
env.close()
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
assert env.num_envs == 8 assert env.num_envs == 8
env.close()
@pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_async_vector_env(shared_memory): def test_reset_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(8)] env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations = env.reset() observations = env.reset()
finally:
env.close() env.close()
assert isinstance(env.observation_space, Box) assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray) assert isinstance(observations, np.ndarray)
@@ -71,20 +69,20 @@ def test_reset_async_vector_env(shared_memory):
@pytest.mark.parametrize("use_single_action_space", [True, False]) @pytest.mark.parametrize("use_single_action_space", [True, False])
def test_step_async_vector_env(shared_memory, use_single_action_space): def test_step_async_vector_env(shared_memory, use_single_action_space):
env_fns = [make_env("CartPole-v1", i) for i in range(8)] env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations = env.reset()
assert isinstance(env.single_action_space, Discrete) env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
assert isinstance(env.action_space, MultiDiscrete) observations = env.reset()
if use_single_action_space: assert isinstance(env.single_action_space, Discrete)
actions = [env.single_action_space.sample() for _ in range(8)] assert isinstance(env.action_space, MultiDiscrete)
else:
actions = env.action_space.sample() if use_single_action_space:
observations, rewards, dones, _ = env.step(actions) actions = [env.single_action_space.sample() for _ in range(8)]
finally: else:
env.close() actions = env.action_space.sample()
observations, rewards, dones, _ = env.step(actions)
env.close()
assert isinstance(env.observation_space, Box) assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray) assert isinstance(observations, np.ndarray)
@@ -106,13 +104,13 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
@pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("shared_memory", [True, False])
def test_call_async_vector_env(shared_memory): def test_call_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(4)] env_fns = [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(4)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
_ = env.reset() _ = env.reset()
images = env.call("render") images = env.call("render")
gravity = env.call("gravity") gravity = env.call("gravity")
finally:
env.close() env.close()
assert isinstance(images, tuple) assert isinstance(images, tuple)
assert len(images) == 4 assert len(images) == 4
@@ -130,79 +128,81 @@ def test_call_async_vector_env(shared_memory):
@pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("shared_memory", [True, False])
def test_set_attr_async_vector_env(shared_memory): def test_set_attr_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(4)] env_fns = [make_env("CartPole-v1", i) for i in range(4)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62]) env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
gravity = env.get_attr("gravity") gravity = env.get_attr("gravity")
assert gravity == (9.81, 3.72, 8.87, 1.62) assert gravity == (9.81, 3.72, 8.87, 1.62)
finally:
env.close() env.close()
@pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("shared_memory", [True, False])
def test_copy_async_vector_env(shared_memory): def test_copy_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(8)] env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=True) # TODO, these tests do nothing, understand the purpose of the tests and fix them
observations = env.reset() env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=True)
observations[0] = 0 observations = env.reset()
finally: observations[0] = 0
env.close()
env.close()
@pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("shared_memory", [True, False])
def test_no_copy_async_vector_env(shared_memory): def test_no_copy_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(8)] env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=False) # TODO, these tests do nothing, understand the purpose of the tests and fix them
observations = env.reset() env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=False)
observations[0] = 0 observations = env.reset()
finally: observations[0] = 0
env.close()
env.close()
@pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_timeout_async_vector_env(shared_memory): def test_reset_timeout_async_vector_env(shared_memory):
env_fns = [make_slow_env(0.3, i) for i in range(4)] env_fns = [make_slow_env(0.3, i) for i in range(4)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
with pytest.raises(TimeoutError): with pytest.raises(TimeoutError):
try: env.reset_async()
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) env.reset_wait(timeout=0.1)
env.reset_async()
env.reset_wait(timeout=0.1) env.close(terminate=True)
finally:
env.close(terminate=True)
@pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("shared_memory", [True, False])
def test_step_timeout_async_vector_env(shared_memory): def test_step_timeout_async_vector_env(shared_memory):
env_fns = [make_slow_env(0.0, i) for i in range(4)] env_fns = [make_slow_env(0.0, i) for i in range(4)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
with pytest.raises(TimeoutError): with pytest.raises(TimeoutError):
try: env.reset()
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) env.step_async(np.array([0.1, 0.1, 0.3, 0.1]))
env.reset() observations, rewards, dones, _ = env.step_wait(timeout=0.1)
env.step_async([0.1, 0.1, 0.3, 0.1]) env.close(terminate=True)
observations, rewards, dones, _ = env.step_wait(timeout=0.1)
finally:
env.close(terminate=True)
@pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("shared_memory", [True, False])
def test_reset_out_of_order_async_vector_env(shared_memory): def test_reset_out_of_order_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(4)] env_fns = [make_env("CartPole-v1", i) for i in range(4)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
with pytest.raises(NoAsyncCallError): with pytest.raises(NoAsyncCallError):
try: try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
env.reset_wait() env.reset_wait()
except NoAsyncCallError as exception: except NoAsyncCallError as exception:
assert exception.name == "reset" assert exception.name == "reset"
raise raise
finally:
env.close(terminate=True)
env.close(terminate=True)
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
with pytest.raises(AlreadyPendingCallError): with pytest.raises(AlreadyPendingCallError):
try: try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
actions = env.action_space.sample() actions = env.action_space.sample()
env.reset() env.reset()
env.step_async(actions) env.step_async(actions)
@@ -210,37 +210,38 @@ def test_reset_out_of_order_async_vector_env(shared_memory):
except NoAsyncCallError as exception: except NoAsyncCallError as exception:
assert exception.name == "step" assert exception.name == "step"
raise raise
finally:
env.close(terminate=True) env.close(terminate=True)
@pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("shared_memory", [True, False])
def test_step_out_of_order_async_vector_env(shared_memory): def test_step_out_of_order_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(4)] env_fns = [make_env("CartPole-v1", i) for i in range(4)]
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
with pytest.raises(NoAsyncCallError): with pytest.raises(NoAsyncCallError):
try: try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) env.action_space.sample()
actions = env.action_space.sample() env.reset()
observations = env.reset() env.step_wait()
observations, rewards, dones, infos = env.step_wait()
except AlreadyPendingCallError as exception: except AlreadyPendingCallError as exception:
assert exception.name == "step" assert exception.name == "step"
raise raise
finally:
env.close(terminate=True)
env.close(terminate=True)
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
with pytest.raises(AlreadyPendingCallError): with pytest.raises(AlreadyPendingCallError):
try: try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
actions = env.action_space.sample() actions = env.action_space.sample()
env.reset_async() env.reset_async()
env.step_async(actions) env.step_async(actions)
except AlreadyPendingCallError as exception: except AlreadyPendingCallError as exception:
assert exception.name == "reset" assert exception.name == "reset"
raise raise
finally:
env.close(terminate=True) env.close(terminate=True)
@pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("shared_memory", [True, False])
@@ -265,17 +266,17 @@ def test_check_spaces_async_vector_env(shared_memory):
def test_custom_space_async_vector_env(): def test_custom_space_async_vector_env():
env_fns = [make_custom_space_env(i) for i in range(4)] env_fns = [make_custom_space_env(i) for i in range(4)]
try:
env = AsyncVectorEnv(env_fns, shared_memory=False)
reset_observations = env.reset()
assert isinstance(env.single_action_space, CustomSpace) env = AsyncVectorEnv(env_fns, shared_memory=False)
assert isinstance(env.action_space, Tuple) reset_observations = env.reset()
actions = ("action-2", "action-3", "action-5", "action-7") assert isinstance(env.single_action_space, CustomSpace)
step_observations, rewards, dones, _ = env.step(actions) assert isinstance(env.action_space, Tuple)
finally:
env.close() actions = ("action-2", "action-3", "action-5", "action-7")
step_observations, rewards, dones, _ = env.step(actions)
env.close()
assert isinstance(env.single_observation_space, CustomSpace) assert isinstance(env.single_observation_space, CustomSpace)
assert isinstance(env.observation_space, Tuple) assert isinstance(env.observation_space, Tuple)

View File

@@ -116,6 +116,7 @@ def test_batch_space_custom_space(space, expected_batch_space_4):
def test_iterate(space, batch_space): def test_iterate(space, batch_space):
items = batch_space.sample() items = batch_space.sample()
iterator = iterate(batch_space, items) iterator = iterate(batch_space, items)
i = 0
for i, item in enumerate(iterator): for i, item in enumerate(iterator):
assert item in space assert item in space
assert i == 3 assert i == 3
@@ -129,6 +130,7 @@ def test_iterate(space, batch_space):
def test_iterate_custom_space(space, batch_space): def test_iterate_custom_space(space, batch_space):
items = batch_space.sample() items = batch_space.sample()
iterator = iterate(batch_space, items) iterator = iterate(batch_space, items)
i = 0
for i, item in enumerate(iterator): for i, item in enumerate(iterator):
assert item in space assert item in space
assert i == 3 assert i == 3

View File

@@ -15,21 +15,17 @@ from tests.vector.utils import (
def test_create_sync_vector_env(): def test_create_sync_vector_env():
env_fns = [make_env("FrozenLake-v1", i) for i in range(8)] env_fns = [make_env("FrozenLake-v1", i) for i in range(8)]
try: env = SyncVectorEnv(env_fns)
env = SyncVectorEnv(env_fns) env.close()
finally:
env.close()
assert env.num_envs == 8 assert env.num_envs == 8
def test_reset_sync_vector_env(): def test_reset_sync_vector_env():
env_fns = [make_env("CartPole-v1", i) for i in range(8)] env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try: env = SyncVectorEnv(env_fns)
env = SyncVectorEnv(env_fns) observations = env.reset()
observations = env.reset() env.close()
finally:
env.close()
assert isinstance(env.observation_space, Box) assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray) assert isinstance(observations, np.ndarray)
@@ -39,11 +35,9 @@ def test_reset_sync_vector_env():
del observations del observations
try: env = SyncVectorEnv(env_fns)
env = SyncVectorEnv(env_fns) observations = env.reset(return_info=False)
observations = env.reset(return_info=False) env.close()
finally:
env.close()
assert isinstance(env.observation_space, Box) assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray) assert isinstance(observations, np.ndarray)
@@ -54,11 +48,10 @@ def test_reset_sync_vector_env():
del observations del observations
env_fns = [make_env("CartPole-v1", i) for i in range(8)] env_fns = [make_env("CartPole-v1", i) for i in range(8)]
try:
env = SyncVectorEnv(env_fns) env = SyncVectorEnv(env_fns)
observations, infos = env.reset(return_info=True) observations, infos = env.reset(return_info=True)
finally: env.close()
env.close()
assert isinstance(env.observation_space, Box) assert isinstance(env.observation_space, Box)
assert isinstance(observations, np.ndarray) assert isinstance(observations, np.ndarray)
@@ -72,20 +65,20 @@ def test_reset_sync_vector_env():
@pytest.mark.parametrize("use_single_action_space", [True, False]) @pytest.mark.parametrize("use_single_action_space", [True, False])
def test_step_sync_vector_env(use_single_action_space): def test_step_sync_vector_env(use_single_action_space):
env_fns = [make_env("FrozenLake-v1", i) for i in range(8)] env_fns = [make_env("FrozenLake-v1", i) for i in range(8)]
try:
env = SyncVectorEnv(env_fns)
observations = env.reset()
assert isinstance(env.single_action_space, Discrete) env = SyncVectorEnv(env_fns)
assert isinstance(env.action_space, MultiDiscrete) observations = env.reset()
if use_single_action_space: assert isinstance(env.single_action_space, Discrete)
actions = [env.single_action_space.sample() for _ in range(8)] assert isinstance(env.action_space, MultiDiscrete)
else:
actions = env.action_space.sample() if use_single_action_space:
observations, rewards, dones, _ = env.step(actions) actions = [env.single_action_space.sample() for _ in range(8)]
finally: else:
env.close() actions = env.action_space.sample()
observations, rewards, dones, _ = env.step(actions)
env.close()
assert isinstance(env.observation_space, MultiDiscrete) assert isinstance(env.observation_space, MultiDiscrete)
assert isinstance(observations, np.ndarray) assert isinstance(observations, np.ndarray)
@@ -106,13 +99,13 @@ def test_step_sync_vector_env(use_single_action_space):
def test_call_sync_vector_env(): def test_call_sync_vector_env():
env_fns = [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(4)] env_fns = [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(4)]
try:
env = SyncVectorEnv(env_fns) env = SyncVectorEnv(env_fns)
_ = env.reset() _ = env.reset()
images = env.call("render") images = env.call("render")
gravity = env.call("gravity") gravity = env.call("gravity")
finally:
env.close() env.close()
assert isinstance(images, tuple) assert isinstance(images, tuple)
assert len(images) == 4 assert len(images) == 4
@@ -129,13 +122,13 @@ def test_call_sync_vector_env():
def test_set_attr_sync_vector_env(): def test_set_attr_sync_vector_env():
env_fns = [make_env("CartPole-v1", i) for i in range(4)] env_fns = [make_env("CartPole-v1", i) for i in range(4)]
try:
env = SyncVectorEnv(env_fns) env = SyncVectorEnv(env_fns)
env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62]) env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
gravity = env.get_attr("gravity") gravity = env.get_attr("gravity")
assert gravity == (9.81, 3.72, 8.87, 1.62) assert gravity == (9.81, 3.72, 8.87, 1.62)
finally:
env.close() env.close()
def test_check_spaces_sync_vector_env(): def test_check_spaces_sync_vector_env():
@@ -150,17 +143,17 @@ def test_check_spaces_sync_vector_env():
def test_custom_space_sync_vector_env(): def test_custom_space_sync_vector_env():
env_fns = [make_custom_space_env(i) for i in range(4)] env_fns = [make_custom_space_env(i) for i in range(4)]
try:
env = SyncVectorEnv(env_fns)
reset_observations = env.reset()
assert isinstance(env.single_action_space, CustomSpace) env = SyncVectorEnv(env_fns)
assert isinstance(env.action_space, Tuple) reset_observations = env.reset()
actions = ("action-2", "action-3", "action-5", "action-7") assert isinstance(env.single_action_space, CustomSpace)
step_observations, rewards, dones, _ = env.step(actions) assert isinstance(env.action_space, Tuple)
finally:
env.close() actions = ("action-2", "action-3", "action-5", "action-7")
step_observations, rewards, dones, _ = env.step(actions)
env.close()
assert isinstance(env.single_observation_space, CustomSpace) assert isinstance(env.single_observation_space, CustomSpace)
assert isinstance(env.observation_space, Tuple) assert isinstance(env.observation_space, Tuple)

View File

@@ -12,42 +12,41 @@ from tests.vector.utils import CustomSpace, make_env
def test_vector_env_equal(shared_memory): def test_vector_env_equal(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(4)] env_fns = [make_env("CartPole-v1", i) for i in range(4)]
num_steps = 100 num_steps = 100
try:
async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
sync_env = SyncVectorEnv(env_fns)
assert async_env.num_envs == sync_env.num_envs async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
assert async_env.observation_space == sync_env.observation_space sync_env = SyncVectorEnv(env_fns)
assert async_env.single_observation_space == sync_env.single_observation_space
assert async_env.action_space == sync_env.action_space assert async_env.num_envs == sync_env.num_envs
assert async_env.single_action_space == sync_env.single_action_space assert async_env.observation_space == sync_env.observation_space
assert async_env.single_observation_space == sync_env.single_observation_space
assert async_env.action_space == sync_env.action_space
assert async_env.single_action_space == sync_env.single_action_space
async_observations = async_env.reset(seed=0)
sync_observations = sync_env.reset(seed=0)
assert np.all(async_observations == sync_observations)
for _ in range(num_steps):
actions = async_env.action_space.sample()
assert actions in sync_env.action_space
# fmt: off
async_observations, async_rewards, async_dones, async_infos = async_env.step(actions)
sync_observations, sync_rewards, sync_dones, sync_infos = sync_env.step(actions)
# fmt: on
if any(sync_dones):
assert "terminal_observation" in async_infos
assert "_terminal_observation" in async_infos
assert "terminal_observation" in sync_infos
assert "_terminal_observation" in sync_infos
async_observations = async_env.reset(seed=0)
sync_observations = sync_env.reset(seed=0)
assert np.all(async_observations == sync_observations) assert np.all(async_observations == sync_observations)
assert np.all(async_rewards == sync_rewards)
assert np.all(async_dones == sync_dones)
for _ in range(num_steps): async_env.close()
actions = async_env.action_space.sample() sync_env.close()
assert actions in sync_env.action_space
# fmt: off
async_observations, async_rewards, async_dones, async_infos = async_env.step(actions)
sync_observations, sync_rewards, sync_dones, sync_infos = sync_env.step(actions)
# fmt: on
if any(sync_dones):
assert "terminal_observation" in async_infos
assert "_terminal_observation" in async_infos
assert "terminal_observation" in sync_infos
assert "_terminal_observation" in sync_infos
assert np.all(async_observations == sync_observations)
assert np.all(async_rewards == sync_rewards)
assert np.all(async_dones == sync_dones)
finally:
async_env.close()
sync_env.close()
def test_custom_space_vector_env(): def test_custom_space_vector_env():

View File

@@ -2,6 +2,7 @@ import numpy as np
import pytest import pytest
import gym import gym
from gym import spaces
from gym.wrappers import AtariPreprocessing, GrayScaleObservation from gym.wrappers import AtariPreprocessing, GrayScaleObservation
pytest.importorskip("gym.envs.atari") pytest.importorskip("gym.envs.atari")
@@ -20,8 +21,12 @@ def test_gray_scale_observation(env_id, keep_dim):
gym.make(env_id, disable_env_checker=True), screen_size=84, grayscale_obs=False gym.make(env_id, disable_env_checker=True), screen_size=84, grayscale_obs=False
) )
wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim) wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim)
assert isinstance(rgb_env.observation_space, spaces.Box)
assert rgb_env.observation_space.shape[-1] == 3 assert rgb_env.observation_space.shape[-1] == 3
assert isinstance(wrapped_env.observation_space, spaces.Box)
seed = 0 seed = 0
gray_obs = gray_env.reset(seed=seed) gray_obs = gray_env.reset(seed=seed)

View File

@@ -26,16 +26,16 @@ def test_order_enforcing():
# Assert that the order enforcing works for step and render before reset # Assert that the order enforcing works for step and render before reset
order_enforced_env = OrderEnforcing(env) order_enforced_env = OrderEnforcing(env)
assert order_enforced_env._has_reset is False assert order_enforced_env.has_reset is False
with pytest.raises(ResetNeeded): with pytest.raises(ResetNeeded):
order_enforced_env.step(0) order_enforced_env.step(0)
with pytest.raises(ResetNeeded): with pytest.raises(ResetNeeded):
order_enforced_env.render(mode="rgb_array") order_enforced_env.render(mode="rgb_array")
assert order_enforced_env._has_reset is False assert order_enforced_env.has_reset is False
# Assert that the Assertion errors are not raised after reset # Assert that the Assertion errors are not raised after reset
order_enforced_env.reset() order_enforced_env.reset()
assert order_enforced_env._has_reset is True assert order_enforced_env.has_reset is True
order_enforced_env.step(0) order_enforced_env.step(0)
order_enforced_env.render(mode="rgb_array") order_enforced_env.render(mode="rgb_array")

View File

@@ -14,6 +14,7 @@ def test_record_episode_statistics(env_id, deque_size):
for n in range(5): for n in range(5):
env.reset() env.reset()
assert env.episode_returns is not None and env.episode_lengths is not None
assert env.episode_returns[0] == 0.0 assert env.episode_returns[0] == 0.0
assert env.episode_lengths[0] == 0 assert env.episode_lengths[0] == 0
for t in range(env.spec.max_episode_steps): for t in range(env.spec.max_episode_steps):

View File

@@ -1,6 +1,7 @@
import pytest import pytest
import gym import gym
from gym import spaces
from gym.wrappers import ResizeObservation from gym.wrappers import ResizeObservation
pytest.importorskip("gym.envs.atari") pytest.importorskip("gym.envs.atari")
@@ -14,6 +15,7 @@ def test_resize_observation(env_id, shape):
env = gym.make(env_id, disable_env_checker=True) env = gym.make(env_id, disable_env_checker=True)
env = ResizeObservation(env, shape) env = ResizeObservation(env, shape)
assert isinstance(env.observation_space, spaces.Box)
assert env.observation_space.shape[-1] == 3 assert env.observation_space.shape[-1] == 3
obs = env.reset() obs = env.reset()
if isinstance(shape, int): if isinstance(shape, int):

View File

@@ -1,6 +1,7 @@
import pytest import pytest
import gym import gym
from gym import spaces
from gym.wrappers import TimeAwareObservation from gym.wrappers import TimeAwareObservation
@@ -9,6 +10,8 @@ def test_time_aware_observation(env_id):
env = gym.make(env_id, disable_env_checker=True) env = gym.make(env_id, disable_env_checker=True)
wrapped_env = TimeAwareObservation(env) wrapped_env = TimeAwareObservation(env)
assert isinstance(env.observation_space, spaces.Box)
assert isinstance(wrapped_env.observation_space, spaces.Box)
assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1 assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1
obs = env.reset() obs = env.reset()

View File

@@ -33,9 +33,10 @@ def test_record_simple():
rec = VideoRecorder(env) rec = VideoRecorder(env)
env.reset() env.reset()
rec.capture_frame() rec.capture_frame()
assert rec.encoder is not None
proc = rec.encoder.proc proc = rec.encoder.proc
assert proc.poll() is None # subprocess is running assert proc is not None and proc.poll() is None # subprocess is running
rec.close() rec.close()
@@ -55,9 +56,10 @@ def test_autoclose():
rec.capture_frame() rec.capture_frame()
rec_path = rec.path rec_path = rec.path
assert rec.encoder is not None
proc = rec.encoder.proc proc = rec.encoder.proc
assert proc.poll() is None # subprocess is running assert proc is not None and proc.poll() is None # subprocess is running
# The function ends without an explicit `rec.close()` call # The function ends without an explicit `rec.close()` call
# The Python interpreter will implicitly do `del rec` on garbage cleaning # The Python interpreter will implicitly do `del rec` on garbage cleaning
@@ -68,7 +70,7 @@ def test_autoclose():
gc.collect() # do explicit garbage collection for test gc.collect() # do explicit garbage collection for test
time.sleep(5) # wait for subprocess exiting time.sleep(5) # wait for subprocess exiting
assert proc.poll() is not None # subprocess is terminated assert proc is not None and proc.poll() is not None # subprocess is terminated
assert os.path.exists(rec_path) assert os.path.exists(rec_path)
f = open(rec_path) f = open(rec_path)
assert os.fstat(f.fileno()).st_size > 100 assert os.fstat(f.fileno()).st_size > 100