mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 01:50:19 +00:00
Full type hinting (#2942)
* Allows a new RNG to be generated with seed=-1 and updated env_checker to fix bug if environment doesn't use np_random in reset
* Revert "fixed `gym.vector.make` where the checker was being applied in the opposite case than was intended to (#2871)"
This reverts commit 519dfd9117
.
* Remove bad pushed commits
* Fixed spelling in core.py
* Pins pytest to the last py 3.6 version
* Allow Box automatic scalar shape
* Add test box and change default from () to (1,)
* update Box shape inference with more strict checking
* Update the box shape and add check on the custom Box shape
* Removed incorrect shape type and assert shape code
* Update the Box and associated tests
* Remove all folders and files from pyright exclude
* Revert issues
* Push RedTachyon code review
* Add Python Platform
* Remove play from pyright check
* Fixed CI issues
* remove mujoco env type hinting
* Fixed pixel observation test
* Added some new type hints
* Fixed CI errors
* Fixed CI errors
* Remove play.py from exlucde pyright
* Fixed pyright issues
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
|
||||
import sys
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
@@ -17,6 +18,9 @@ from gym.logger import deprecation, warn
|
||||
from gym.utils import seeding
|
||||
from gym.utils.seeding import RandomNumberGenerator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gym.envs.registration import EnvSpec
|
||||
|
||||
if sys.version_info[0:2] == (3, 6):
|
||||
warn(
|
||||
"Gym minimally supports python 3.6 as the python foundation not longer supports the version, please update your version to 3.7+"
|
||||
@@ -106,7 +110,7 @@ class Env(Generic[ObsType, ActType], metaclass=decorator):
|
||||
metadata = {"render_modes": []}
|
||||
render_mode = None # define render_mode if your environment supports rendering
|
||||
reward_range = (-float("inf"), float("inf"))
|
||||
spec = None
|
||||
spec: "EnvSpec" = None
|
||||
|
||||
# Set these in ALL subclasses
|
||||
action_space: spaces.Space[ActType]
|
||||
|
@@ -1,7 +1,7 @@
|
||||
__credits__ = ["Andrea PIERRÉ"]
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -25,6 +25,9 @@ except ImportError:
|
||||
raise DependencyNotInstalled("box2D is not installed, run `pip install gym[box2d]`")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pygame
|
||||
|
||||
FPS = 50
|
||||
SCALE = 30.0 # affects how fast-paced the game is, forces should be adjusted as well
|
||||
|
||||
@@ -170,8 +173,8 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
self.isopen = True
|
||||
|
||||
self.world = Box2D.b2World()
|
||||
self.terrain = None
|
||||
self.hull = None
|
||||
self.terrain: List[Box2D.b2Body] = []
|
||||
self.hull: Optional[Box2D.b2Body] = None
|
||||
|
||||
self.prev_shaping = None
|
||||
|
||||
@@ -256,7 +259,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
|
||||
self.render_mode = render_mode
|
||||
self.renderer = Renderer(self.render_mode, self._render)
|
||||
self.screen = None
|
||||
self.screen: Optional[pygame.Surface] = None
|
||||
self.clock = None
|
||||
|
||||
def _destroy(self):
|
||||
@@ -283,6 +286,9 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
self.terrain = []
|
||||
self.terrain_x = []
|
||||
self.terrain_y = []
|
||||
|
||||
stair_steps, stair_width, stair_height = 0, 0, 0
|
||||
original_y = 0
|
||||
for i in range(TERRAIN_LENGTH):
|
||||
x = i * TERRAIN_STEP
|
||||
self.terrain_x.append(x)
|
||||
@@ -448,8 +454,8 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
(self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True
|
||||
)
|
||||
|
||||
self.legs = []
|
||||
self.joints = []
|
||||
self.legs: List[Box2D.b2Body] = []
|
||||
self.joints: List[Box2D.b2RevoluteJoint] = []
|
||||
for i in [-1, +1]:
|
||||
leg = self.world.CreateDynamicBody(
|
||||
position=(init_x, init_y - LEG_H / 2 - LEG_DOWN),
|
||||
@@ -514,6 +520,8 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
return self.step(np.array([0, 0, 0, 0]))[0], {}
|
||||
|
||||
def step(self, action: np.ndarray):
|
||||
assert self.hull is not None
|
||||
|
||||
# self.hull.ApplyForceToCenter((0, 20), True) -- Uncomment this to receive a bit of stability help
|
||||
control_speed = False # Should be easier as well
|
||||
if control_speed:
|
||||
@@ -737,6 +745,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
self.surf = pygame.transform.flip(self.surf, False, True)
|
||||
|
||||
if mode == "human":
|
||||
assert self.screen is not None
|
||||
self.screen.blit(self.surf, (-self.scroll * SCALE, 0))
|
||||
pygame.event.pump()
|
||||
self.clock.tick(self.metadata["render_fps"])
|
||||
|
@@ -9,6 +9,7 @@ Created by Oleg Klimov
|
||||
|
||||
import math
|
||||
|
||||
import Box2D
|
||||
import numpy as np
|
||||
|
||||
from gym.error import DependencyNotInstalled
|
||||
@@ -48,8 +49,8 @@ MUD_COLOR = (102, 102, 0)
|
||||
|
||||
class Car:
|
||||
def __init__(self, world, init_angle, init_x, init_y):
|
||||
self.world = world
|
||||
self.hull = self.world.CreateDynamicBody(
|
||||
self.world: Box2D.b2World = world
|
||||
self.hull: Box2D.b2Body = self.world.CreateDynamicBody(
|
||||
position=(init_x, init_y),
|
||||
angle=init_angle,
|
||||
fixtures=[
|
||||
|
@@ -197,14 +197,14 @@ class CarRacing(gym.Env, EzPickle):
|
||||
|
||||
self.contactListener_keepref = FrictionDetector(self, lap_complete_percent)
|
||||
self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref)
|
||||
self.screen = None
|
||||
self.screen: Optional[pygame.Surface] = None
|
||||
self.surf = None
|
||||
self.clock = None
|
||||
self.isopen = True
|
||||
self.invisible_state_window = None
|
||||
self.invisible_video_window = None
|
||||
self.road = None
|
||||
self.car = None
|
||||
self.car: Optional[Car] = None
|
||||
self.reward = 0.0
|
||||
self.prev_reward = 0.0
|
||||
self.verbose = verbose
|
||||
@@ -237,6 +237,7 @@ class CarRacing(gym.Env, EzPickle):
|
||||
for t in self.road:
|
||||
self.world.DestroyBody(t)
|
||||
self.road = []
|
||||
assert self.car is not None
|
||||
self.car.destroy()
|
||||
|
||||
def _init_colors(self):
|
||||
@@ -502,6 +503,7 @@ class CarRacing(gym.Env, EzPickle):
|
||||
return self.step(None)[0], {}
|
||||
|
||||
def step(self, action: Union[np.ndarray, int]):
|
||||
assert self.car is not None
|
||||
if action is not None:
|
||||
if self.continuous:
|
||||
self.car.steer(-action[0])
|
||||
@@ -576,6 +578,7 @@ class CarRacing(gym.Env, EzPickle):
|
||||
|
||||
self.surf = pygame.Surface((WINDOW_W, WINDOW_H))
|
||||
|
||||
assert self.car is not None
|
||||
# computing transformations
|
||||
angle = -self.car.hull.angle
|
||||
# Animating first second zoom.
|
||||
@@ -608,6 +611,7 @@ class CarRacing(gym.Env, EzPickle):
|
||||
if mode == "human":
|
||||
pygame.event.pump()
|
||||
self.clock.tick(self.metadata["render_fps"])
|
||||
assert self.screen is not None
|
||||
self.screen.fill(0)
|
||||
self.screen.blit(self.surf, (0, 0))
|
||||
pygame.display.flip()
|
||||
@@ -682,6 +686,7 @@ class CarRacing(gym.Env, EzPickle):
|
||||
((place + 0) * s, H - 2 * h),
|
||||
]
|
||||
|
||||
assert self.car is not None
|
||||
true_speed = np.sqrt(
|
||||
np.square(self.car.hull.linearVelocity[0])
|
||||
+ np.square(self.car.hull.linearVelocity[1])
|
||||
|
@@ -2,7 +2,7 @@ __credits__ = ["Andrea PIERRÉ"]
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -25,6 +25,11 @@ try:
|
||||
except ImportError:
|
||||
raise DependencyNotInstalled("box2d is not installed, run `pip install gym[box2d]`")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pygame
|
||||
|
||||
|
||||
FPS = 50
|
||||
SCALE = 30.0 # affects how fast-paced the game is, forces should be adjusted as well
|
||||
|
||||
@@ -215,7 +220,7 @@ class LunarLander(gym.Env, EzPickle):
|
||||
self.wind_idx = np.random.randint(-9999, 9999)
|
||||
self.torque_idx = np.random.randint(-9999, 9999)
|
||||
|
||||
self.screen = None
|
||||
self.screen: pygame.Surface = None
|
||||
self.clock = None
|
||||
self.isopen = True
|
||||
self.world = Box2D.b2World(gravity=(0, gravity))
|
||||
@@ -427,6 +432,8 @@ class LunarLander(gym.Env, EzPickle):
|
||||
self.world.DestroyBody(self.particles.pop(0))
|
||||
|
||||
def step(self, action):
|
||||
assert self.lander is not None
|
||||
|
||||
# Update wind
|
||||
assert self.lander is not None, "You forgot to call reset()"
|
||||
if self.enable_wind and not (
|
||||
@@ -604,10 +611,6 @@ class LunarLander(gym.Env, EzPickle):
|
||||
if self.clock is None:
|
||||
self.clock = pygame.time.Clock()
|
||||
|
||||
assert (
|
||||
self.screen is not None
|
||||
), "Something went wrong with pygame, there is no screen to render"
|
||||
|
||||
self.surf = pygame.Surface((VIEWPORT_W, VIEWPORT_H))
|
||||
|
||||
pygame.transform.scale(self.surf, (SCALE, SCALE))
|
||||
|
@@ -79,4 +79,5 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return self._get_obs()
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
self.viewer.cam.distance = self.model.stat.extent * 0.5
|
||||
|
@@ -173,6 +173,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return observation
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
getattr(self.viewer.cam, key)[:] = value
|
||||
|
@@ -336,6 +336,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return observation
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
getattr(self.viewer.cam, key)[:] = value
|
||||
|
@@ -60,4 +60,5 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return self._get_obs()
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
self.viewer.cam.distance = self.model.stat.extent * 0.5
|
||||
|
@@ -118,6 +118,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return observation
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
getattr(self.viewer.cam, key)[:] = value
|
||||
|
@@ -231,6 +231,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return observation
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
getattr(self.viewer.cam, key)[:] = value
|
||||
|
@@ -66,6 +66,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return self._get_obs()
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
self.viewer.cam.trackbodyid = 2
|
||||
self.viewer.cam.distance = self.model.stat.extent * 0.75
|
||||
self.viewer.cam.lookat[2] = 1.15
|
||||
|
@@ -163,6 +163,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return observation
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
getattr(self.viewer.cam, key)[:] = value
|
||||
|
@@ -279,6 +279,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return observation
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
getattr(self.viewer.cam, key)[:] = value
|
||||
|
@@ -92,6 +92,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return self._get_obs()
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
self.viewer.cam.trackbodyid = 1
|
||||
self.viewer.cam.distance = self.model.stat.extent * 1.0
|
||||
self.viewer.cam.lookat[2] = 2.0
|
||||
|
@@ -185,6 +185,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return observation
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
getattr(self.viewer.cam, key)[:] = value
|
||||
|
@@ -358,6 +358,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return observation
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
getattr(self.viewer.cam, key)[:] = value
|
||||
|
@@ -83,6 +83,7 @@ class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return self._get_obs()
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
self.viewer.cam.trackbodyid = 1
|
||||
self.viewer.cam.distance = self.model.stat.extent * 1.0
|
||||
self.viewer.cam.lookat[2] = 0.8925
|
||||
|
@@ -254,6 +254,7 @@ class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return self._get_obs()
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
self.viewer.cam.trackbodyid = 1
|
||||
self.viewer.cam.distance = self.model.stat.extent * 1.0
|
||||
self.viewer.cam.lookat[2] = 0.8925
|
||||
|
@@ -64,6 +64,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return self._get_obs()
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
v = self.viewer
|
||||
v.cam.trackbodyid = 0
|
||||
v.cam.distance = self.model.stat.extent * 0.5
|
||||
|
@@ -169,6 +169,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return self._get_obs()
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
v = self.viewer
|
||||
v.cam.trackbodyid = 0
|
||||
v.cam.distance = self.model.stat.extent * 0.5
|
||||
|
@@ -54,6 +54,6 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return np.concatenate([self.sim.data.qpos, self.sim.data.qvel]).ravel()
|
||||
|
||||
def viewer_setup(self):
|
||||
v = self.viewer
|
||||
v.cam.trackbodyid = 0
|
||||
v.cam.distance = self.model.stat.extent
|
||||
assert self.viewer is not None
|
||||
self.viewer.cam.trackbodyid = 0
|
||||
self.viewer.cam.distance = self.model.stat.extent
|
||||
|
@@ -130,6 +130,7 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return np.concatenate([self.data.qpos, self.data.qvel]).ravel()
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
v = self.viewer
|
||||
v.cam.trackbodyid = 0
|
||||
v.cam.distance = self.model.stat.extent
|
||||
|
@@ -47,6 +47,7 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
self.viewer.cam.trackbodyid = -1
|
||||
self.viewer.cam.distance = 4.0
|
||||
|
||||
|
@@ -164,6 +164,7 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
self.viewer.cam.trackbodyid = -1
|
||||
self.viewer.cam.distance = 4.0
|
||||
|
||||
|
@@ -43,6 +43,7 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
self.viewer.cam.trackbodyid = 0
|
||||
|
||||
def reset_model(self):
|
||||
|
@@ -150,6 +150,7 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
self.viewer.cam.trackbodyid = 0
|
||||
|
||||
def reset_model(self):
|
||||
|
@@ -119,6 +119,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return observation
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
getattr(self.viewer.cam, key)[:] = value
|
||||
|
@@ -225,6 +225,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return observation
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
getattr(self.viewer.cam, key)[:] = value
|
||||
|
@@ -60,6 +60,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return self._get_obs()
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
self.viewer.cam.trackbodyid = 2
|
||||
self.viewer.cam.distance = self.model.stat.extent * 0.5
|
||||
self.viewer.cam.lookat[2] = 1.15
|
||||
|
@@ -153,6 +153,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return observation
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
getattr(self.viewer.cam, key)[:] = value
|
||||
|
@@ -277,6 +277,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||
return observation
|
||||
|
||||
def viewer_setup(self):
|
||||
assert self.viewer is not None
|
||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
getattr(self.viewer.cam, key)[:] = value
|
||||
|
@@ -134,10 +134,12 @@ class Graph(Space):
|
||||
assert (
|
||||
num_edges >= 0
|
||||
), f"The number of edges is expected to be greater than 0, actual mask: {num_edges}"
|
||||
assert num_edges is not None
|
||||
|
||||
sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
|
||||
sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges)
|
||||
|
||||
assert sampled_node_space is not None
|
||||
sampled_nodes = sampled_node_space.sample(node_space_mask)
|
||||
sampled_edges = (
|
||||
sampled_edge_space.sample(edge_space_mask)
|
||||
|
@@ -3,16 +3,23 @@ from collections import deque
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pygame
|
||||
from pygame import Surface
|
||||
from pygame.event import Event
|
||||
from pygame.locals import VIDEORESIZE
|
||||
|
||||
import gym.error
|
||||
from gym import Env, logger
|
||||
from gym.core import ActType, ObsType
|
||||
from gym.error import DependencyNotInstalled
|
||||
from gym.logger import deprecation
|
||||
|
||||
try:
|
||||
import pygame
|
||||
from pygame import Surface
|
||||
from pygame.event import Event
|
||||
from pygame.locals import VIDEORESIZE
|
||||
except ImportError:
|
||||
raise gym.error.DependencyNotInstalled(
|
||||
"Pygame is not installed, run `pip install gym[classic_control]`"
|
||||
)
|
||||
|
||||
try:
|
||||
import matplotlib
|
||||
|
||||
@@ -20,7 +27,7 @@ try:
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError:
|
||||
logger.warn("Matplotlib is not installed, run `pip install gym[other]`")
|
||||
matplotlib, plt = None, None
|
||||
plt = None
|
||||
|
||||
|
||||
class MissingKeysToAction(Exception):
|
||||
@@ -33,7 +40,7 @@ class PlayableGame:
|
||||
def __init__(
|
||||
self,
|
||||
env: Env,
|
||||
keys_to_action: Optional[Dict[Tuple[int], int]] = None,
|
||||
keys_to_action: Optional[Dict[Tuple[int, ...], int]] = None,
|
||||
zoom: Optional[float] = None,
|
||||
):
|
||||
"""Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.
|
||||
@@ -63,12 +70,14 @@ class PlayableGame:
|
||||
f"{self.env.spec.id} does not have explicit key to action mapping, "
|
||||
"please specify one manually"
|
||||
)
|
||||
assert isinstance(keys_to_action, dict)
|
||||
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
|
||||
return relevant_keys
|
||||
|
||||
def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]:
|
||||
# TODO: this needs to be updated when the render API change goes through
|
||||
rendered = self.env.render(mode="rgb_array")
|
||||
assert rendered is not None and isinstance(rendered, np.ndarray)
|
||||
video_size = [rendered.shape[1], rendered.shape[0]]
|
||||
|
||||
if zoom is not None:
|
||||
@@ -211,9 +220,9 @@ def play(
|
||||
f"{env.spec.id} does not have explicit key to action mapping, "
|
||||
"please specify one manually"
|
||||
)
|
||||
assert keys_to_action is not None
|
||||
|
||||
key_code_to_action = {}
|
||||
|
||||
for key_combination, action in keys_to_action.items():
|
||||
key_code = tuple(
|
||||
sorted(ord(key) if isinstance(key, str) else key for key in key_combination)
|
||||
@@ -225,7 +234,7 @@ def play(
|
||||
if fps is None:
|
||||
fps = env.metadata.get("render_fps", 30)
|
||||
|
||||
done = True
|
||||
done, obs = True, None
|
||||
clock = pygame.time.Clock()
|
||||
|
||||
while game.running:
|
||||
@@ -316,7 +325,7 @@ class PlayPlot:
|
||||
for axis, name in zip(self.ax, plot_names):
|
||||
axis.set_title(name)
|
||||
self.t = 0
|
||||
self.cur_plot = [None for _ in range(num_plots)]
|
||||
self.cur_plot: List[Optional[plt.Axes]] = [None for _ in range(num_plots)]
|
||||
self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]
|
||||
|
||||
def callback(
|
||||
@@ -352,4 +361,9 @@ class PlayPlot:
|
||||
range(xmin, xmax), list(self.data[i]), c="blue"
|
||||
)
|
||||
self.ax[i].set_xlim(xmin, xmax)
|
||||
|
||||
if plt is None:
|
||||
raise DependencyNotInstalled(
|
||||
"matplotlib is not installed, run `pip install gym[other]`"
|
||||
)
|
||||
plt.pause(0.000001)
|
||||
|
@@ -569,7 +569,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
|
||||
num_errors = self.num_envs - sum(successes)
|
||||
assert num_errors > 0
|
||||
for _ in range(num_errors):
|
||||
for i in range(num_errors):
|
||||
index, exctype, value = self.error_queue.get()
|
||||
logger.error(
|
||||
f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
|
||||
@@ -578,6 +578,7 @@ class AsyncVectorEnv(VectorEnv):
|
||||
self.parent_pipes[index].close()
|
||||
self.parent_pipes[index] = None
|
||||
|
||||
if i == num_errors - 1:
|
||||
logger.error("Raising the last exception back to the main process.")
|
||||
raise exctype(value)
|
||||
|
||||
|
@@ -1,9 +1,10 @@
|
||||
"""A synchronous vector environment."""
|
||||
from copy import deepcopy
|
||||
from typing import Any, Iterator, List, Optional, Sequence, Union
|
||||
from typing import Any, Callable, Iterator, List, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gym import Env
|
||||
from gym.spaces import Space
|
||||
from gym.vector.utils import concatenate, create_empty_array, iterate
|
||||
from gym.vector.vector_env import VectorEnv
|
||||
@@ -28,7 +29,7 @@ class SyncVectorEnv(VectorEnv):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: Iterator[callable],
|
||||
env_fns: Iterator[Callable[[], Env]],
|
||||
observation_space: Space = None,
|
||||
action_space: Space = None,
|
||||
copy: bool = True,
|
||||
|
@@ -6,7 +6,8 @@ from gym.vector.utils.shared_memory import (
|
||||
read_from_shared_memory,
|
||||
write_to_shared_memory,
|
||||
)
|
||||
from gym.vector.utils.spaces import BaseGymSpaces, _BaseGymSpaces, batch_space, iterate
|
||||
from gym.vector.utils.spaces import _BaseGymSpaces # pyright: reportPrivateUsage=false
|
||||
from gym.vector.utils.spaces import BaseGymSpaces, batch_space, iterate
|
||||
|
||||
__all__ = [
|
||||
"CloudpickleWrapper",
|
||||
@@ -17,7 +18,6 @@ __all__ = [
|
||||
"read_from_shared_memory",
|
||||
"write_to_shared_memory",
|
||||
"BaseGymSpaces",
|
||||
"_BaseGymSpaces",
|
||||
"batch_space",
|
||||
"iterate",
|
||||
]
|
||||
|
@@ -85,6 +85,7 @@ class VectorEnv(gym.Env):
|
||||
Raises:
|
||||
NotImplementedError: VectorEnv does not implement function
|
||||
"""
|
||||
raise NotImplementedError("VectorEnv does not implement function")
|
||||
|
||||
def reset(
|
||||
self,
|
||||
|
@@ -2,13 +2,14 @@
|
||||
import numpy as np
|
||||
|
||||
import gym
|
||||
from gym.error import DependencyNotInstalled
|
||||
from gym.spaces import Box
|
||||
|
||||
try:
|
||||
import cv2
|
||||
except ImportError:
|
||||
cv2 = None
|
||||
raise gym.error.DependencyNotInstalled(
|
||||
"opencv-python package not installed, run `pip install gym[other]` to get dependencies for atari"
|
||||
)
|
||||
|
||||
|
||||
class AtariPreprocessing(gym.Wrapper):
|
||||
@@ -60,10 +61,6 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
ValueError: Disable frame-skipping in the original env
|
||||
"""
|
||||
super().__init__(env)
|
||||
if cv2 is None:
|
||||
raise DependencyNotInstalled(
|
||||
"opencv-python package not installed, run `pip install gym[other]` to get dependencies for atari"
|
||||
)
|
||||
assert frame_skip > 0
|
||||
assert screen_size > 0
|
||||
assert noop_max >= 0
|
||||
@@ -87,6 +84,7 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
self.scale_obs = scale_obs
|
||||
|
||||
# buffer of most recent two observations for max pooling
|
||||
assert isinstance(env.observation_space, Box)
|
||||
if grayscale_obs:
|
||||
self.obs_buffer = [
|
||||
np.empty(env.observation_space.shape[:2], dtype=np.uint8),
|
||||
@@ -114,7 +112,7 @@ class AtariPreprocessing(gym.Wrapper):
|
||||
|
||||
def step(self, action):
|
||||
"""Applies the preprocessing for an :meth:`env.step`."""
|
||||
total_reward = 0.0
|
||||
total_reward, done, info = 0.0, False, {}
|
||||
|
||||
for t in range(self.frame_skip):
|
||||
_, reward, done, info = self.env.step(action)
|
||||
|
@@ -32,8 +32,9 @@ class GrayScaleObservation(gym.ObservationWrapper):
|
||||
self.keep_dim = keep_dim
|
||||
|
||||
assert (
|
||||
len(env.observation_space.shape) == 3
|
||||
and env.observation_space.shape[-1] == 3
|
||||
isinstance(self.observation_space, Box)
|
||||
and len(self.observation_space.shape) == 3
|
||||
and self.observation_space.shape[-1] == 3
|
||||
)
|
||||
|
||||
obs_shape = self.observation_space.shape[:2]
|
||||
|
@@ -88,13 +88,16 @@ class HumanRendering(gym.Wrapper):
|
||||
"pygame is not installed, run `pip install gym[box2d]`"
|
||||
)
|
||||
if self.env.render_mode == "rgb_array":
|
||||
last_rgb_array = self.env.render(**kwargs)[-1]
|
||||
last_rgb_array = self.env.render(**kwargs)
|
||||
assert isinstance(last_rgb_array, list)
|
||||
last_rgb_array = last_rgb_array[-1]
|
||||
elif self.env.render_mode == "single_rgb_array":
|
||||
last_rgb_array = self.env.render(**kwargs)
|
||||
else:
|
||||
raise Exception(
|
||||
"Wrapped environment must have mode 'rgb_array' or 'single_rgb_array'"
|
||||
f"Wrapped environment must have mode 'rgb_array' or 'single_rgb_array', actual render mode: {self.env.render_mode}"
|
||||
)
|
||||
assert isinstance(last_rgb_array, np.ndarray)
|
||||
|
||||
if mode == "human":
|
||||
rgb_array = np.transpose(last_rgb_array, axes=(1, 0, 2))
|
||||
|
@@ -148,7 +148,9 @@ class VideoRecorder: # TODO: remove with gym 1.0
|
||||
)
|
||||
self.output_frames_per_sec = self.backward_compatible_output_frames_per_sec
|
||||
|
||||
self.encoder = None # lazily start the process
|
||||
self.encoder: Optional[
|
||||
Union[TextEncoder, ImageEncoder]
|
||||
] = None # lazily start the process
|
||||
self.broken = False
|
||||
|
||||
# Dump metadata
|
||||
@@ -387,7 +389,7 @@ class ImageEncoder:
|
||||
InvalidFrame: Expects frame to have shape (w,h,3) or (w,h,4)
|
||||
DependencyNotInstalled: Found neither the ffmpeg nor avconv executables.
|
||||
"""
|
||||
self.proc = None
|
||||
self.proc: Optional[subprocess.Popen] = None
|
||||
self.output_path = output_path
|
||||
# Frame shape should be lines-first, so w and h are swapped
|
||||
h, w, pixfmt = frame_shape
|
||||
@@ -488,6 +490,7 @@ class ImageEncoder:
|
||||
f"Your frame has data type {frame.dtype}, but we require uint8 (i.e. RGB values from 0-255)."
|
||||
)
|
||||
|
||||
assert self.proc is not None and self.proc.stdin is not None
|
||||
try:
|
||||
self.proc.stdin.write(frame.tobytes())
|
||||
except Exception:
|
||||
@@ -496,6 +499,7 @@ class ImageEncoder:
|
||||
|
||||
def close(self):
|
||||
"""Closes the Image encoder."""
|
||||
assert self.proc is not None and self.proc.stdin is not None
|
||||
self.proc.stdin.close()
|
||||
ret = self.proc.wait()
|
||||
if ret != 0:
|
||||
|
@@ -81,19 +81,20 @@ class NormalizeObservation(gym.core.Wrapper):
|
||||
|
||||
def reset(self, **kwargs):
|
||||
"""Resets the environment and normalizes the observation."""
|
||||
return_info = kwargs.get("return_info", False)
|
||||
if return_info:
|
||||
if kwargs.get("return_info", False):
|
||||
obs, info = self.env.reset(**kwargs)
|
||||
|
||||
if self.is_vector_env:
|
||||
return self.normalize(obs), info
|
||||
else:
|
||||
return self.normalize(np.array([obs]))[0], info
|
||||
else:
|
||||
obs = self.env.reset(**kwargs)
|
||||
|
||||
if self.is_vector_env:
|
||||
obs = self.normalize(obs)
|
||||
return self.normalize(obs)
|
||||
else:
|
||||
obs = self.normalize(np.array([obs]))[0]
|
||||
if not return_info:
|
||||
return obs
|
||||
else:
|
||||
return obs, info
|
||||
return self.normalize(np.array([obs]))[0]
|
||||
|
||||
def normalize(self, obs):
|
||||
"""Normalises the observation using the running mean and variance of the observations."""
|
||||
|
@@ -49,3 +49,8 @@ class OrderEnforcing(gym.Wrapper):
|
||||
"set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
|
||||
)
|
||||
return self.env.render(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def has_reset(self):
|
||||
"""Returns if the environment has been reset before."""
|
||||
return self._has_reset
|
||||
|
@@ -120,8 +120,7 @@ class PixelObservationWrapper(gym.ObservationWrapper):
|
||||
elif self._observation_is_dict:
|
||||
self.observation_space = copy.deepcopy(wrapped_observation_space)
|
||||
else:
|
||||
self.observation_space = spaces.Dict()
|
||||
self.observation_space.spaces[STATE_KEY] = wrapped_observation_space
|
||||
self.observation_space = spaces.Dict({STATE_KEY: wrapped_observation_space})
|
||||
|
||||
# Extend observation space with pixels.
|
||||
|
||||
@@ -129,7 +128,7 @@ class PixelObservationWrapper(gym.ObservationWrapper):
|
||||
pixels_spaces = {}
|
||||
for pixel_key in pixel_keys:
|
||||
pixels = self.env.render(**render_kwargs[pixel_key])
|
||||
pixels = pixels[-1] if isinstance(pixels, List) else pixels
|
||||
pixels: np.ndarray = pixels[-1] if isinstance(pixels, List) else pixels
|
||||
|
||||
if np.issubdtype(pixels.dtype, np.integer):
|
||||
low, high = (0, 255)
|
||||
|
@@ -1,6 +1,7 @@
|
||||
"""Wrapper that tracks the cumulative rewards and episode lengths."""
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -86,8 +87,8 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
||||
self.num_envs = getattr(env, "num_envs", 1)
|
||||
self.t0 = time.perf_counter()
|
||||
self.episode_count = 0
|
||||
self.episode_returns = None
|
||||
self.episode_lengths = None
|
||||
self.episode_returns: Optional[np.ndarray] = None
|
||||
self.episode_lengths: Optional[np.ndarray] = None
|
||||
self.return_queue = deque(maxlen=deque_size)
|
||||
self.length_queue = deque(maxlen=deque_size)
|
||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||
|
@@ -1,6 +1,6 @@
|
||||
"""Wrapper for recording videos."""
|
||||
import os
|
||||
from typing import Callable
|
||||
from typing import Callable, Optional
|
||||
|
||||
import gym
|
||||
from gym import logger
|
||||
@@ -77,7 +77,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
|
||||
|
||||
self.episode_trigger = episode_trigger
|
||||
self.step_trigger = step_trigger
|
||||
self.video_recorder = None
|
||||
self.video_recorder: Optional[video_recorder.VideoRecorder] = None
|
||||
|
||||
self.video_folder = os.path.abspath(video_folder)
|
||||
# Create output folder if needed
|
||||
@@ -101,6 +101,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
|
||||
"""Reset the environment using kwargs and then starts recording if video enabled."""
|
||||
observations = super().reset(**kwargs)
|
||||
if self.recording:
|
||||
assert self.video_recorder is not None
|
||||
self.video_recorder.capture_frame()
|
||||
self.recorded_frames += 1
|
||||
if self.video_length > 0:
|
||||
@@ -148,6 +149,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
|
||||
self.episode_id += 1
|
||||
|
||||
if self.recording:
|
||||
assert self.video_recorder is not None
|
||||
self.video_recorder.capture_frame()
|
||||
self.recorded_frames += 1
|
||||
if self.video_length > 0:
|
||||
@@ -168,6 +170,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
|
||||
def close_video_recorder(self):
|
||||
"""Closes the video recorder if currently recording."""
|
||||
if self.recording:
|
||||
assert self.video_recorder is not None
|
||||
self.video_recorder.close()
|
||||
self.recording = False
|
||||
self.recorded_frames = 1
|
||||
|
@@ -39,7 +39,10 @@ class ResizeObservation(gym.ObservationWrapper):
|
||||
|
||||
self.shape = tuple(shape)
|
||||
|
||||
obs_shape = self.shape + self.observation_space.shape[2:]
|
||||
assert isinstance(
|
||||
env.observation_space, Box
|
||||
), f"Expected the observation space to be Box, actual type: {type(env.observation_space)}"
|
||||
obs_shape = self.shape + env.observation_space.shape[2:]
|
||||
self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
|
||||
|
||||
def observation(self, observation):
|
||||
|
@@ -9,40 +9,7 @@ exclude = [
|
||||
"**/node_modules",
|
||||
"**/__pycache__",
|
||||
|
||||
"gym/envs/box2d/bipedal_walker.py",
|
||||
"gym/envs/box2d/car_racing.py",
|
||||
|
||||
"gym/spaces/graph.py",
|
||||
|
||||
"gym/envs/mujoco/**",
|
||||
"gym/utils/play.py",
|
||||
|
||||
"gym/vector/async_vector_env.py",
|
||||
"gym/vector/utils/__init__.py",
|
||||
|
||||
"gym/wrappers/atari_preprocessing.py",
|
||||
"gym/wrappers/gray_scale_observation.py",
|
||||
"gym/wrappers/human_rendering.py",
|
||||
"gym/wrappers/normalize.py",
|
||||
"gym/wrappers/pixel_observation.py",
|
||||
"gym/wrappers/record_video.py",
|
||||
"gym/wrappers/monitoring/video_recorder.py",
|
||||
"gym/wrappers/resize_observation.py",
|
||||
|
||||
"tests/envs/test_env_implementation.py",
|
||||
"tests/utils/test_play.py",
|
||||
"tests/vector/test_async_vector_env.py",
|
||||
"tests/vector/test_shared_memory.py",
|
||||
"tests/vector/test_spaces.py",
|
||||
"tests/vector/test_sync_vector_env.py",
|
||||
"tests/vector/test_vector_env.py",
|
||||
"tests/wrappers/test_gray_scale_observation.py",
|
||||
"tests/wrappers/test_order_enforcing.py",
|
||||
"tests/wrappers/test_record_episode_statistics.py",
|
||||
"tests/wrappers/test_resize_observation.py",
|
||||
"tests/wrappers/test_time_aware_observation.py",
|
||||
"tests/wrappers/test_video_recorder.py",
|
||||
|
||||
"gym/envs/mujoco/mujoco_env.py",
|
||||
]
|
||||
|
||||
strict = [
|
||||
@@ -51,6 +18,7 @@ strict = [
|
||||
|
||||
typeCheckingMode = "basic"
|
||||
pythonVersion = "3.6"
|
||||
pythonPlatform = "All"
|
||||
typeshedPath = "typeshed"
|
||||
enableTypeIgnoreComments = true
|
||||
|
||||
|
@@ -1,5 +1,5 @@
|
||||
"""Finds all the specs that we can test with"""
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -18,22 +18,27 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
|
||||
|
||||
|
||||
# Tries to make all environment to test with
|
||||
all_testing_initialised_envs = list(
|
||||
filter(None, [try_make_env(env_spec) for env_spec in gym.envs.registry.values()])
|
||||
)
|
||||
all_testing_initialised_envs: List[Optional[gym.Env]] = [
|
||||
try_make_env(env_spec) for env_spec in gym.envs.registry.values()
|
||||
]
|
||||
all_testing_initialised_envs: List[gym.Env] = [
|
||||
env for env in all_testing_initialised_envs if env is not None
|
||||
]
|
||||
|
||||
# All testing, mujoco and gym environment specs
|
||||
all_testing_env_specs = [env.spec for env in all_testing_initialised_envs]
|
||||
mujoco_testing_env_specs = [
|
||||
all_testing_env_specs: List[EnvSpec] = [
|
||||
env.spec for env in all_testing_initialised_envs
|
||||
]
|
||||
mujoco_testing_env_specs: List[EnvSpec] = [
|
||||
env_spec
|
||||
for env_spec in all_testing_env_specs
|
||||
if "gym.envs.mujoco" in env_spec.entry_point
|
||||
]
|
||||
gym_testing_env_specs = [
|
||||
gym_testing_env_specs: List[EnvSpec] = [
|
||||
env_spec
|
||||
for env_spec in all_testing_env_specs
|
||||
if any(
|
||||
f"gym.{ep}" in env_spec.entry_point
|
||||
f"gym.envs.{ep}" in env_spec.entry_point
|
||||
for ep in ["box2d", "classic_control", "toy_text"]
|
||||
)
|
||||
]
|
||||
|
@@ -194,7 +194,7 @@ def test_play_loop_real_env():
|
||||
|
||||
# first action is 0 because at the first iteration
|
||||
# we can not inject a callback event into play()
|
||||
env.step(0)
|
||||
obs, _, _, _ = env.step(0)
|
||||
for e in keydown_events:
|
||||
action = keys_to_action[chr(e.key) if str_keys else (e.key,)]
|
||||
obs, _, _, _ = env.step(action)
|
||||
|
@@ -17,21 +17,19 @@ from tests.vector.utils import (
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||
def test_create_async_vector_env(shared_memory):
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
assert env.num_envs == 8
|
||||
env.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||
def test_reset_async_vector_env(shared_memory):
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
try:
|
||||
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
observations = env.reset()
|
||||
finally:
|
||||
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
@@ -71,7 +69,7 @@ def test_reset_async_vector_env(shared_memory):
|
||||
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
||||
def test_step_async_vector_env(shared_memory, use_single_action_space):
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
try:
|
||||
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
observations = env.reset()
|
||||
|
||||
@@ -83,7 +81,7 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
|
||||
else:
|
||||
actions = env.action_space.sample()
|
||||
observations, rewards, dones, _ = env.step(actions)
|
||||
finally:
|
||||
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
@@ -106,12 +104,12 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||
def test_call_async_vector_env(shared_memory):
|
||||
env_fns = [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(4)]
|
||||
try:
|
||||
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
_ = env.reset()
|
||||
images = env.call("render")
|
||||
gravity = env.call("gravity")
|
||||
finally:
|
||||
|
||||
env.close()
|
||||
|
||||
assert isinstance(images, tuple)
|
||||
@@ -130,59 +128,60 @@ def test_call_async_vector_env(shared_memory):
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||
def test_set_attr_async_vector_env(shared_memory):
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
||||
try:
|
||||
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
|
||||
gravity = env.get_attr("gravity")
|
||||
assert gravity == (9.81, 3.72, 8.87, 1.62)
|
||||
finally:
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||
def test_copy_async_vector_env(shared_memory):
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
try:
|
||||
|
||||
# TODO, these tests do nothing, understand the purpose of the tests and fix them
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=True)
|
||||
observations = env.reset()
|
||||
observations[0] = 0
|
||||
finally:
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||
def test_no_copy_async_vector_env(shared_memory):
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
try:
|
||||
|
||||
# TODO, these tests do nothing, understand the purpose of the tests and fix them
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=False)
|
||||
observations = env.reset()
|
||||
observations[0] = 0
|
||||
finally:
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||
def test_reset_timeout_async_vector_env(shared_memory):
|
||||
env_fns = [make_slow_env(0.3, i) for i in range(4)]
|
||||
with pytest.raises(TimeoutError):
|
||||
try:
|
||||
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
with pytest.raises(TimeoutError):
|
||||
env.reset_async()
|
||||
env.reset_wait(timeout=0.1)
|
||||
finally:
|
||||
|
||||
env.close(terminate=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||
def test_step_timeout_async_vector_env(shared_memory):
|
||||
env_fns = [make_slow_env(0.0, i) for i in range(4)]
|
||||
with pytest.raises(TimeoutError):
|
||||
try:
|
||||
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
with pytest.raises(TimeoutError):
|
||||
env.reset()
|
||||
env.step_async([0.1, 0.1, 0.3, 0.1])
|
||||
env.step_async(np.array([0.1, 0.1, 0.3, 0.1]))
|
||||
observations, rewards, dones, _ = env.step_wait(timeout=0.1)
|
||||
finally:
|
||||
env.close(terminate=True)
|
||||
|
||||
|
||||
@@ -190,19 +189,20 @@ def test_step_timeout_async_vector_env(shared_memory):
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||
def test_reset_out_of_order_async_vector_env(shared_memory):
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
||||
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
with pytest.raises(NoAsyncCallError):
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
env.reset_wait()
|
||||
except NoAsyncCallError as exception:
|
||||
assert exception.name == "reset"
|
||||
raise
|
||||
finally:
|
||||
|
||||
env.close(terminate=True)
|
||||
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
with pytest.raises(AlreadyPendingCallError):
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
actions = env.action_space.sample()
|
||||
env.reset()
|
||||
env.step_async(actions)
|
||||
@@ -210,7 +210,7 @@ def test_reset_out_of_order_async_vector_env(shared_memory):
|
||||
except NoAsyncCallError as exception:
|
||||
assert exception.name == "step"
|
||||
raise
|
||||
finally:
|
||||
|
||||
env.close(terminate=True)
|
||||
|
||||
|
||||
@@ -218,28 +218,29 @@ def test_reset_out_of_order_async_vector_env(shared_memory):
|
||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||
def test_step_out_of_order_async_vector_env(shared_memory):
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
||||
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
with pytest.raises(NoAsyncCallError):
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
actions = env.action_space.sample()
|
||||
observations = env.reset()
|
||||
observations, rewards, dones, infos = env.step_wait()
|
||||
env.action_space.sample()
|
||||
env.reset()
|
||||
env.step_wait()
|
||||
except AlreadyPendingCallError as exception:
|
||||
assert exception.name == "step"
|
||||
raise
|
||||
finally:
|
||||
|
||||
env.close(terminate=True)
|
||||
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
with pytest.raises(AlreadyPendingCallError):
|
||||
try:
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
actions = env.action_space.sample()
|
||||
env.reset_async()
|
||||
env.step_async(actions)
|
||||
except AlreadyPendingCallError as exception:
|
||||
assert exception.name == "reset"
|
||||
raise
|
||||
finally:
|
||||
|
||||
env.close(terminate=True)
|
||||
|
||||
|
||||
@@ -265,7 +266,7 @@ def test_check_spaces_async_vector_env(shared_memory):
|
||||
|
||||
def test_custom_space_async_vector_env():
|
||||
env_fns = [make_custom_space_env(i) for i in range(4)]
|
||||
try:
|
||||
|
||||
env = AsyncVectorEnv(env_fns, shared_memory=False)
|
||||
reset_observations = env.reset()
|
||||
|
||||
@@ -274,7 +275,7 @@ def test_custom_space_async_vector_env():
|
||||
|
||||
actions = ("action-2", "action-3", "action-5", "action-7")
|
||||
step_observations, rewards, dones, _ = env.step(actions)
|
||||
finally:
|
||||
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.single_observation_space, CustomSpace)
|
||||
|
@@ -116,6 +116,7 @@ def test_batch_space_custom_space(space, expected_batch_space_4):
|
||||
def test_iterate(space, batch_space):
|
||||
items = batch_space.sample()
|
||||
iterator = iterate(batch_space, items)
|
||||
i = 0
|
||||
for i, item in enumerate(iterator):
|
||||
assert item in space
|
||||
assert i == 3
|
||||
@@ -129,6 +130,7 @@ def test_iterate(space, batch_space):
|
||||
def test_iterate_custom_space(space, batch_space):
|
||||
items = batch_space.sample()
|
||||
iterator = iterate(batch_space, items)
|
||||
i = 0
|
||||
for i, item in enumerate(iterator):
|
||||
assert item in space
|
||||
assert i == 3
|
||||
|
@@ -15,9 +15,7 @@ from tests.vector.utils import (
|
||||
|
||||
def test_create_sync_vector_env():
|
||||
env_fns = [make_env("FrozenLake-v1", i) for i in range(8)]
|
||||
try:
|
||||
env = SyncVectorEnv(env_fns)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert env.num_envs == 8
|
||||
@@ -25,10 +23,8 @@ def test_create_sync_vector_env():
|
||||
|
||||
def test_reset_sync_vector_env():
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
try:
|
||||
env = SyncVectorEnv(env_fns)
|
||||
observations = env.reset()
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
@@ -39,10 +35,8 @@ def test_reset_sync_vector_env():
|
||||
|
||||
del observations
|
||||
|
||||
try:
|
||||
env = SyncVectorEnv(env_fns)
|
||||
observations = env.reset(return_info=False)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
@@ -54,10 +48,9 @@ def test_reset_sync_vector_env():
|
||||
del observations
|
||||
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||
try:
|
||||
|
||||
env = SyncVectorEnv(env_fns)
|
||||
observations, infos = env.reset(return_info=True)
|
||||
finally:
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, Box)
|
||||
@@ -72,7 +65,7 @@ def test_reset_sync_vector_env():
|
||||
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
||||
def test_step_sync_vector_env(use_single_action_space):
|
||||
env_fns = [make_env("FrozenLake-v1", i) for i in range(8)]
|
||||
try:
|
||||
|
||||
env = SyncVectorEnv(env_fns)
|
||||
observations = env.reset()
|
||||
|
||||
@@ -84,7 +77,7 @@ def test_step_sync_vector_env(use_single_action_space):
|
||||
else:
|
||||
actions = env.action_space.sample()
|
||||
observations, rewards, dones, _ = env.step(actions)
|
||||
finally:
|
||||
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.observation_space, MultiDiscrete)
|
||||
@@ -106,12 +99,12 @@ def test_step_sync_vector_env(use_single_action_space):
|
||||
|
||||
def test_call_sync_vector_env():
|
||||
env_fns = [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(4)]
|
||||
try:
|
||||
|
||||
env = SyncVectorEnv(env_fns)
|
||||
_ = env.reset()
|
||||
images = env.call("render")
|
||||
gravity = env.call("gravity")
|
||||
finally:
|
||||
|
||||
env.close()
|
||||
|
||||
assert isinstance(images, tuple)
|
||||
@@ -129,12 +122,12 @@ def test_call_sync_vector_env():
|
||||
|
||||
def test_set_attr_sync_vector_env():
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
||||
try:
|
||||
|
||||
env = SyncVectorEnv(env_fns)
|
||||
env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
|
||||
gravity = env.get_attr("gravity")
|
||||
assert gravity == (9.81, 3.72, 8.87, 1.62)
|
||||
finally:
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
@@ -150,7 +143,7 @@ def test_check_spaces_sync_vector_env():
|
||||
|
||||
def test_custom_space_sync_vector_env():
|
||||
env_fns = [make_custom_space_env(i) for i in range(4)]
|
||||
try:
|
||||
|
||||
env = SyncVectorEnv(env_fns)
|
||||
reset_observations = env.reset()
|
||||
|
||||
@@ -159,7 +152,7 @@ def test_custom_space_sync_vector_env():
|
||||
|
||||
actions = ("action-2", "action-3", "action-5", "action-7")
|
||||
step_observations, rewards, dones, _ = env.step(actions)
|
||||
finally:
|
||||
|
||||
env.close()
|
||||
|
||||
assert isinstance(env.single_observation_space, CustomSpace)
|
||||
|
@@ -12,7 +12,7 @@ from tests.vector.utils import CustomSpace, make_env
|
||||
def test_vector_env_equal(shared_memory):
|
||||
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
||||
num_steps = 100
|
||||
try:
|
||||
|
||||
async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||
sync_env = SyncVectorEnv(env_fns)
|
||||
|
||||
@@ -45,7 +45,6 @@ def test_vector_env_equal(shared_memory):
|
||||
assert np.all(async_rewards == sync_rewards)
|
||||
assert np.all(async_dones == sync_dones)
|
||||
|
||||
finally:
|
||||
async_env.close()
|
||||
sync_env.close()
|
||||
|
||||
|
@@ -2,6 +2,7 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
from gym.wrappers import AtariPreprocessing, GrayScaleObservation
|
||||
|
||||
pytest.importorskip("gym.envs.atari")
|
||||
@@ -20,8 +21,12 @@ def test_gray_scale_observation(env_id, keep_dim):
|
||||
gym.make(env_id, disable_env_checker=True), screen_size=84, grayscale_obs=False
|
||||
)
|
||||
wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim)
|
||||
|
||||
assert isinstance(rgb_env.observation_space, spaces.Box)
|
||||
assert rgb_env.observation_space.shape[-1] == 3
|
||||
|
||||
assert isinstance(wrapped_env.observation_space, spaces.Box)
|
||||
|
||||
seed = 0
|
||||
|
||||
gray_obs = gray_env.reset(seed=seed)
|
||||
|
@@ -26,16 +26,16 @@ def test_order_enforcing():
|
||||
|
||||
# Assert that the order enforcing works for step and render before reset
|
||||
order_enforced_env = OrderEnforcing(env)
|
||||
assert order_enforced_env._has_reset is False
|
||||
assert order_enforced_env.has_reset is False
|
||||
with pytest.raises(ResetNeeded):
|
||||
order_enforced_env.step(0)
|
||||
with pytest.raises(ResetNeeded):
|
||||
order_enforced_env.render(mode="rgb_array")
|
||||
assert order_enforced_env._has_reset is False
|
||||
assert order_enforced_env.has_reset is False
|
||||
|
||||
# Assert that the Assertion errors are not raised after reset
|
||||
order_enforced_env.reset()
|
||||
assert order_enforced_env._has_reset is True
|
||||
assert order_enforced_env.has_reset is True
|
||||
order_enforced_env.step(0)
|
||||
order_enforced_env.render(mode="rgb_array")
|
||||
|
||||
|
@@ -14,6 +14,7 @@ def test_record_episode_statistics(env_id, deque_size):
|
||||
|
||||
for n in range(5):
|
||||
env.reset()
|
||||
assert env.episode_returns is not None and env.episode_lengths is not None
|
||||
assert env.episode_returns[0] == 0.0
|
||||
assert env.episode_lengths[0] == 0
|
||||
for t in range(env.spec.max_episode_steps):
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
from gym.wrappers import ResizeObservation
|
||||
|
||||
pytest.importorskip("gym.envs.atari")
|
||||
@@ -14,6 +15,7 @@ def test_resize_observation(env_id, shape):
|
||||
env = gym.make(env_id, disable_env_checker=True)
|
||||
env = ResizeObservation(env, shape)
|
||||
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert env.observation_space.shape[-1] == 3
|
||||
obs = env.reset()
|
||||
if isinstance(shape, int):
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
from gym.wrappers import TimeAwareObservation
|
||||
|
||||
|
||||
@@ -9,6 +10,8 @@ def test_time_aware_observation(env_id):
|
||||
env = gym.make(env_id, disable_env_checker=True)
|
||||
wrapped_env = TimeAwareObservation(env)
|
||||
|
||||
assert isinstance(env.observation_space, spaces.Box)
|
||||
assert isinstance(wrapped_env.observation_space, spaces.Box)
|
||||
assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1
|
||||
|
||||
obs = env.reset()
|
||||
|
@@ -33,9 +33,10 @@ def test_record_simple():
|
||||
rec = VideoRecorder(env)
|
||||
env.reset()
|
||||
rec.capture_frame()
|
||||
assert rec.encoder is not None
|
||||
proc = rec.encoder.proc
|
||||
|
||||
assert proc.poll() is None # subprocess is running
|
||||
assert proc is not None and proc.poll() is None # subprocess is running
|
||||
|
||||
rec.close()
|
||||
|
||||
@@ -55,9 +56,10 @@ def test_autoclose():
|
||||
rec.capture_frame()
|
||||
|
||||
rec_path = rec.path
|
||||
assert rec.encoder is not None
|
||||
proc = rec.encoder.proc
|
||||
|
||||
assert proc.poll() is None # subprocess is running
|
||||
assert proc is not None and proc.poll() is None # subprocess is running
|
||||
|
||||
# The function ends without an explicit `rec.close()` call
|
||||
# The Python interpreter will implicitly do `del rec` on garbage cleaning
|
||||
@@ -68,7 +70,7 @@ def test_autoclose():
|
||||
gc.collect() # do explicit garbage collection for test
|
||||
time.sleep(5) # wait for subprocess exiting
|
||||
|
||||
assert proc.poll() is not None # subprocess is terminated
|
||||
assert proc is not None and proc.poll() is not None # subprocess is terminated
|
||||
assert os.path.exists(rec_path)
|
||||
f = open(rec_path)
|
||||
assert os.fstat(f.fileno()).st_size > 100
|
||||
|
Reference in New Issue
Block a user