mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 17:57:30 +00:00
Full type hinting (#2942)
* Allows a new RNG to be generated with seed=-1 and updated env_checker to fix bug if environment doesn't use np_random in reset
* Revert "fixed `gym.vector.make` where the checker was being applied in the opposite case than was intended to (#2871)"
This reverts commit 519dfd9117
.
* Remove bad pushed commits
* Fixed spelling in core.py
* Pins pytest to the last py 3.6 version
* Allow Box automatic scalar shape
* Add test box and change default from () to (1,)
* update Box shape inference with more strict checking
* Update the box shape and add check on the custom Box shape
* Removed incorrect shape type and assert shape code
* Update the Box and associated tests
* Remove all folders and files from pyright exclude
* Revert issues
* Push RedTachyon code review
* Add Python Platform
* Remove play from pyright check
* Fixed CI issues
* remove mujoco env type hinting
* Fixed pixel observation test
* Added some new type hints
* Fixed CI errors
* Fixed CI errors
* Remove play.py from exlucde pyright
* Fixed pyright issues
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
|
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
|
||||||
import sys
|
import sys
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
Dict,
|
||||||
Generic,
|
Generic,
|
||||||
@@ -17,6 +18,9 @@ from gym.logger import deprecation, warn
|
|||||||
from gym.utils import seeding
|
from gym.utils import seeding
|
||||||
from gym.utils.seeding import RandomNumberGenerator
|
from gym.utils.seeding import RandomNumberGenerator
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from gym.envs.registration import EnvSpec
|
||||||
|
|
||||||
if sys.version_info[0:2] == (3, 6):
|
if sys.version_info[0:2] == (3, 6):
|
||||||
warn(
|
warn(
|
||||||
"Gym minimally supports python 3.6 as the python foundation not longer supports the version, please update your version to 3.7+"
|
"Gym minimally supports python 3.6 as the python foundation not longer supports the version, please update your version to 3.7+"
|
||||||
@@ -106,7 +110,7 @@ class Env(Generic[ObsType, ActType], metaclass=decorator):
|
|||||||
metadata = {"render_modes": []}
|
metadata = {"render_modes": []}
|
||||||
render_mode = None # define render_mode if your environment supports rendering
|
render_mode = None # define render_mode if your environment supports rendering
|
||||||
reward_range = (-float("inf"), float("inf"))
|
reward_range = (-float("inf"), float("inf"))
|
||||||
spec = None
|
spec: "EnvSpec" = None
|
||||||
|
|
||||||
# Set these in ALL subclasses
|
# Set these in ALL subclasses
|
||||||
action_space: spaces.Space[ActType]
|
action_space: spaces.Space[ActType]
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
__credits__ = ["Andrea PIERRÉ"]
|
__credits__ = ["Andrea PIERRÉ"]
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -25,6 +25,9 @@ except ImportError:
|
|||||||
raise DependencyNotInstalled("box2D is not installed, run `pip install gym[box2d]`")
|
raise DependencyNotInstalled("box2D is not installed, run `pip install gym[box2d]`")
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import pygame
|
||||||
|
|
||||||
FPS = 50
|
FPS = 50
|
||||||
SCALE = 30.0 # affects how fast-paced the game is, forces should be adjusted as well
|
SCALE = 30.0 # affects how fast-paced the game is, forces should be adjusted as well
|
||||||
|
|
||||||
@@ -170,8 +173,8 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
self.isopen = True
|
self.isopen = True
|
||||||
|
|
||||||
self.world = Box2D.b2World()
|
self.world = Box2D.b2World()
|
||||||
self.terrain = None
|
self.terrain: List[Box2D.b2Body] = []
|
||||||
self.hull = None
|
self.hull: Optional[Box2D.b2Body] = None
|
||||||
|
|
||||||
self.prev_shaping = None
|
self.prev_shaping = None
|
||||||
|
|
||||||
@@ -256,7 +259,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
|
|
||||||
self.render_mode = render_mode
|
self.render_mode = render_mode
|
||||||
self.renderer = Renderer(self.render_mode, self._render)
|
self.renderer = Renderer(self.render_mode, self._render)
|
||||||
self.screen = None
|
self.screen: Optional[pygame.Surface] = None
|
||||||
self.clock = None
|
self.clock = None
|
||||||
|
|
||||||
def _destroy(self):
|
def _destroy(self):
|
||||||
@@ -283,6 +286,9 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
self.terrain = []
|
self.terrain = []
|
||||||
self.terrain_x = []
|
self.terrain_x = []
|
||||||
self.terrain_y = []
|
self.terrain_y = []
|
||||||
|
|
||||||
|
stair_steps, stair_width, stair_height = 0, 0, 0
|
||||||
|
original_y = 0
|
||||||
for i in range(TERRAIN_LENGTH):
|
for i in range(TERRAIN_LENGTH):
|
||||||
x = i * TERRAIN_STEP
|
x = i * TERRAIN_STEP
|
||||||
self.terrain_x.append(x)
|
self.terrain_x.append(x)
|
||||||
@@ -448,8 +454,8 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
(self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True
|
(self.np_random.uniform(-INITIAL_RANDOM, INITIAL_RANDOM), 0), True
|
||||||
)
|
)
|
||||||
|
|
||||||
self.legs = []
|
self.legs: List[Box2D.b2Body] = []
|
||||||
self.joints = []
|
self.joints: List[Box2D.b2RevoluteJoint] = []
|
||||||
for i in [-1, +1]:
|
for i in [-1, +1]:
|
||||||
leg = self.world.CreateDynamicBody(
|
leg = self.world.CreateDynamicBody(
|
||||||
position=(init_x, init_y - LEG_H / 2 - LEG_DOWN),
|
position=(init_x, init_y - LEG_H / 2 - LEG_DOWN),
|
||||||
@@ -514,6 +520,8 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
return self.step(np.array([0, 0, 0, 0]))[0], {}
|
return self.step(np.array([0, 0, 0, 0]))[0], {}
|
||||||
|
|
||||||
def step(self, action: np.ndarray):
|
def step(self, action: np.ndarray):
|
||||||
|
assert self.hull is not None
|
||||||
|
|
||||||
# self.hull.ApplyForceToCenter((0, 20), True) -- Uncomment this to receive a bit of stability help
|
# self.hull.ApplyForceToCenter((0, 20), True) -- Uncomment this to receive a bit of stability help
|
||||||
control_speed = False # Should be easier as well
|
control_speed = False # Should be easier as well
|
||||||
if control_speed:
|
if control_speed:
|
||||||
@@ -737,6 +745,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
self.surf = pygame.transform.flip(self.surf, False, True)
|
self.surf = pygame.transform.flip(self.surf, False, True)
|
||||||
|
|
||||||
if mode == "human":
|
if mode == "human":
|
||||||
|
assert self.screen is not None
|
||||||
self.screen.blit(self.surf, (-self.scroll * SCALE, 0))
|
self.screen.blit(self.surf, (-self.scroll * SCALE, 0))
|
||||||
pygame.event.pump()
|
pygame.event.pump()
|
||||||
self.clock.tick(self.metadata["render_fps"])
|
self.clock.tick(self.metadata["render_fps"])
|
||||||
|
@@ -9,6 +9,7 @@ Created by Oleg Klimov
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
import Box2D
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from gym.error import DependencyNotInstalled
|
from gym.error import DependencyNotInstalled
|
||||||
@@ -48,8 +49,8 @@ MUD_COLOR = (102, 102, 0)
|
|||||||
|
|
||||||
class Car:
|
class Car:
|
||||||
def __init__(self, world, init_angle, init_x, init_y):
|
def __init__(self, world, init_angle, init_x, init_y):
|
||||||
self.world = world
|
self.world: Box2D.b2World = world
|
||||||
self.hull = self.world.CreateDynamicBody(
|
self.hull: Box2D.b2Body = self.world.CreateDynamicBody(
|
||||||
position=(init_x, init_y),
|
position=(init_x, init_y),
|
||||||
angle=init_angle,
|
angle=init_angle,
|
||||||
fixtures=[
|
fixtures=[
|
||||||
|
@@ -197,14 +197,14 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
|
|
||||||
self.contactListener_keepref = FrictionDetector(self, lap_complete_percent)
|
self.contactListener_keepref = FrictionDetector(self, lap_complete_percent)
|
||||||
self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref)
|
self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref)
|
||||||
self.screen = None
|
self.screen: Optional[pygame.Surface] = None
|
||||||
self.surf = None
|
self.surf = None
|
||||||
self.clock = None
|
self.clock = None
|
||||||
self.isopen = True
|
self.isopen = True
|
||||||
self.invisible_state_window = None
|
self.invisible_state_window = None
|
||||||
self.invisible_video_window = None
|
self.invisible_video_window = None
|
||||||
self.road = None
|
self.road = None
|
||||||
self.car = None
|
self.car: Optional[Car] = None
|
||||||
self.reward = 0.0
|
self.reward = 0.0
|
||||||
self.prev_reward = 0.0
|
self.prev_reward = 0.0
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
@@ -237,6 +237,7 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
for t in self.road:
|
for t in self.road:
|
||||||
self.world.DestroyBody(t)
|
self.world.DestroyBody(t)
|
||||||
self.road = []
|
self.road = []
|
||||||
|
assert self.car is not None
|
||||||
self.car.destroy()
|
self.car.destroy()
|
||||||
|
|
||||||
def _init_colors(self):
|
def _init_colors(self):
|
||||||
@@ -502,6 +503,7 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
return self.step(None)[0], {}
|
return self.step(None)[0], {}
|
||||||
|
|
||||||
def step(self, action: Union[np.ndarray, int]):
|
def step(self, action: Union[np.ndarray, int]):
|
||||||
|
assert self.car is not None
|
||||||
if action is not None:
|
if action is not None:
|
||||||
if self.continuous:
|
if self.continuous:
|
||||||
self.car.steer(-action[0])
|
self.car.steer(-action[0])
|
||||||
@@ -576,6 +578,7 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
|
|
||||||
self.surf = pygame.Surface((WINDOW_W, WINDOW_H))
|
self.surf = pygame.Surface((WINDOW_W, WINDOW_H))
|
||||||
|
|
||||||
|
assert self.car is not None
|
||||||
# computing transformations
|
# computing transformations
|
||||||
angle = -self.car.hull.angle
|
angle = -self.car.hull.angle
|
||||||
# Animating first second zoom.
|
# Animating first second zoom.
|
||||||
@@ -608,6 +611,7 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
if mode == "human":
|
if mode == "human":
|
||||||
pygame.event.pump()
|
pygame.event.pump()
|
||||||
self.clock.tick(self.metadata["render_fps"])
|
self.clock.tick(self.metadata["render_fps"])
|
||||||
|
assert self.screen is not None
|
||||||
self.screen.fill(0)
|
self.screen.fill(0)
|
||||||
self.screen.blit(self.surf, (0, 0))
|
self.screen.blit(self.surf, (0, 0))
|
||||||
pygame.display.flip()
|
pygame.display.flip()
|
||||||
@@ -682,6 +686,7 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
((place + 0) * s, H - 2 * h),
|
((place + 0) * s, H - 2 * h),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
assert self.car is not None
|
||||||
true_speed = np.sqrt(
|
true_speed = np.sqrt(
|
||||||
np.square(self.car.hull.linearVelocity[0])
|
np.square(self.car.hull.linearVelocity[0])
|
||||||
+ np.square(self.car.hull.linearVelocity[1])
|
+ np.square(self.car.hull.linearVelocity[1])
|
||||||
|
@@ -2,7 +2,7 @@ __credits__ = ["Andrea PIERRÉ"]
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -25,6 +25,11 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise DependencyNotInstalled("box2d is not installed, run `pip install gym[box2d]`")
|
raise DependencyNotInstalled("box2d is not installed, run `pip install gym[box2d]`")
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import pygame
|
||||||
|
|
||||||
|
|
||||||
FPS = 50
|
FPS = 50
|
||||||
SCALE = 30.0 # affects how fast-paced the game is, forces should be adjusted as well
|
SCALE = 30.0 # affects how fast-paced the game is, forces should be adjusted as well
|
||||||
|
|
||||||
@@ -215,7 +220,7 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
self.wind_idx = np.random.randint(-9999, 9999)
|
self.wind_idx = np.random.randint(-9999, 9999)
|
||||||
self.torque_idx = np.random.randint(-9999, 9999)
|
self.torque_idx = np.random.randint(-9999, 9999)
|
||||||
|
|
||||||
self.screen = None
|
self.screen: pygame.Surface = None
|
||||||
self.clock = None
|
self.clock = None
|
||||||
self.isopen = True
|
self.isopen = True
|
||||||
self.world = Box2D.b2World(gravity=(0, gravity))
|
self.world = Box2D.b2World(gravity=(0, gravity))
|
||||||
@@ -427,6 +432,8 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
self.world.DestroyBody(self.particles.pop(0))
|
self.world.DestroyBody(self.particles.pop(0))
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
assert self.lander is not None
|
||||||
|
|
||||||
# Update wind
|
# Update wind
|
||||||
assert self.lander is not None, "You forgot to call reset()"
|
assert self.lander is not None, "You forgot to call reset()"
|
||||||
if self.enable_wind and not (
|
if self.enable_wind and not (
|
||||||
@@ -604,10 +611,6 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
if self.clock is None:
|
if self.clock is None:
|
||||||
self.clock = pygame.time.Clock()
|
self.clock = pygame.time.Clock()
|
||||||
|
|
||||||
assert (
|
|
||||||
self.screen is not None
|
|
||||||
), "Something went wrong with pygame, there is no screen to render"
|
|
||||||
|
|
||||||
self.surf = pygame.Surface((VIEWPORT_W, VIEWPORT_H))
|
self.surf = pygame.Surface((VIEWPORT_W, VIEWPORT_H))
|
||||||
|
|
||||||
pygame.transform.scale(self.surf, (SCALE, SCALE))
|
pygame.transform.scale(self.surf, (SCALE, SCALE))
|
||||||
|
@@ -79,4 +79,5 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
self.viewer.cam.distance = self.model.stat.extent * 0.5
|
self.viewer.cam.distance = self.model.stat.extent * 0.5
|
||||||
|
@@ -173,6 +173,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, np.ndarray):
|
||||||
getattr(self.viewer.cam, key)[:] = value
|
getattr(self.viewer.cam, key)[:] = value
|
||||||
|
@@ -336,6 +336,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, np.ndarray):
|
||||||
getattr(self.viewer.cam, key)[:] = value
|
getattr(self.viewer.cam, key)[:] = value
|
||||||
|
@@ -60,4 +60,5 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
self.viewer.cam.distance = self.model.stat.extent * 0.5
|
self.viewer.cam.distance = self.model.stat.extent * 0.5
|
||||||
|
@@ -118,6 +118,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, np.ndarray):
|
||||||
getattr(self.viewer.cam, key)[:] = value
|
getattr(self.viewer.cam, key)[:] = value
|
||||||
|
@@ -231,6 +231,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, np.ndarray):
|
||||||
getattr(self.viewer.cam, key)[:] = value
|
getattr(self.viewer.cam, key)[:] = value
|
||||||
|
@@ -66,6 +66,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
self.viewer.cam.trackbodyid = 2
|
self.viewer.cam.trackbodyid = 2
|
||||||
self.viewer.cam.distance = self.model.stat.extent * 0.75
|
self.viewer.cam.distance = self.model.stat.extent * 0.75
|
||||||
self.viewer.cam.lookat[2] = 1.15
|
self.viewer.cam.lookat[2] = 1.15
|
||||||
|
@@ -163,6 +163,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, np.ndarray):
|
||||||
getattr(self.viewer.cam, key)[:] = value
|
getattr(self.viewer.cam, key)[:] = value
|
||||||
|
@@ -279,6 +279,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, np.ndarray):
|
||||||
getattr(self.viewer.cam, key)[:] = value
|
getattr(self.viewer.cam, key)[:] = value
|
||||||
|
@@ -92,6 +92,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
self.viewer.cam.trackbodyid = 1
|
self.viewer.cam.trackbodyid = 1
|
||||||
self.viewer.cam.distance = self.model.stat.extent * 1.0
|
self.viewer.cam.distance = self.model.stat.extent * 1.0
|
||||||
self.viewer.cam.lookat[2] = 2.0
|
self.viewer.cam.lookat[2] = 2.0
|
||||||
|
@@ -185,6 +185,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, np.ndarray):
|
||||||
getattr(self.viewer.cam, key)[:] = value
|
getattr(self.viewer.cam, key)[:] = value
|
||||||
|
@@ -358,6 +358,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, np.ndarray):
|
||||||
getattr(self.viewer.cam, key)[:] = value
|
getattr(self.viewer.cam, key)[:] = value
|
||||||
|
@@ -83,6 +83,7 @@ class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
self.viewer.cam.trackbodyid = 1
|
self.viewer.cam.trackbodyid = 1
|
||||||
self.viewer.cam.distance = self.model.stat.extent * 1.0
|
self.viewer.cam.distance = self.model.stat.extent * 1.0
|
||||||
self.viewer.cam.lookat[2] = 0.8925
|
self.viewer.cam.lookat[2] = 0.8925
|
||||||
|
@@ -254,6 +254,7 @@ class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
self.viewer.cam.trackbodyid = 1
|
self.viewer.cam.trackbodyid = 1
|
||||||
self.viewer.cam.distance = self.model.stat.extent * 1.0
|
self.viewer.cam.distance = self.model.stat.extent * 1.0
|
||||||
self.viewer.cam.lookat[2] = 0.8925
|
self.viewer.cam.lookat[2] = 0.8925
|
||||||
|
@@ -64,6 +64,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
v = self.viewer
|
v = self.viewer
|
||||||
v.cam.trackbodyid = 0
|
v.cam.trackbodyid = 0
|
||||||
v.cam.distance = self.model.stat.extent * 0.5
|
v.cam.distance = self.model.stat.extent * 0.5
|
||||||
|
@@ -169,6 +169,7 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
v = self.viewer
|
v = self.viewer
|
||||||
v.cam.trackbodyid = 0
|
v.cam.trackbodyid = 0
|
||||||
v.cam.distance = self.model.stat.extent * 0.5
|
v.cam.distance = self.model.stat.extent * 0.5
|
||||||
|
@@ -54,6 +54,6 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return np.concatenate([self.sim.data.qpos, self.sim.data.qvel]).ravel()
|
return np.concatenate([self.sim.data.qpos, self.sim.data.qvel]).ravel()
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
v = self.viewer
|
assert self.viewer is not None
|
||||||
v.cam.trackbodyid = 0
|
self.viewer.cam.trackbodyid = 0
|
||||||
v.cam.distance = self.model.stat.extent
|
self.viewer.cam.distance = self.model.stat.extent
|
||||||
|
@@ -130,6 +130,7 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return np.concatenate([self.data.qpos, self.data.qvel]).ravel()
|
return np.concatenate([self.data.qpos, self.data.qvel]).ravel()
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
v = self.viewer
|
v = self.viewer
|
||||||
v.cam.trackbodyid = 0
|
v.cam.trackbodyid = 0
|
||||||
v.cam.distance = self.model.stat.extent
|
v.cam.distance = self.model.stat.extent
|
||||||
|
@@ -47,6 +47,7 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
|
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
self.viewer.cam.trackbodyid = -1
|
self.viewer.cam.trackbodyid = -1
|
||||||
self.viewer.cam.distance = 4.0
|
self.viewer.cam.distance = 4.0
|
||||||
|
|
||||||
|
@@ -164,6 +164,7 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
|
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
self.viewer.cam.trackbodyid = -1
|
self.viewer.cam.trackbodyid = -1
|
||||||
self.viewer.cam.distance = 4.0
|
self.viewer.cam.distance = 4.0
|
||||||
|
|
||||||
|
@@ -43,6 +43,7 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
|
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
self.viewer.cam.trackbodyid = 0
|
self.viewer.cam.trackbodyid = 0
|
||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
|
@@ -150,6 +150,7 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
|
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
self.viewer.cam.trackbodyid = 0
|
self.viewer.cam.trackbodyid = 0
|
||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
|
@@ -119,6 +119,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, np.ndarray):
|
||||||
getattr(self.viewer.cam, key)[:] = value
|
getattr(self.viewer.cam, key)[:] = value
|
||||||
|
@@ -225,6 +225,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, np.ndarray):
|
||||||
getattr(self.viewer.cam, key)[:] = value
|
getattr(self.viewer.cam, key)[:] = value
|
||||||
|
@@ -60,6 +60,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
self.viewer.cam.trackbodyid = 2
|
self.viewer.cam.trackbodyid = 2
|
||||||
self.viewer.cam.distance = self.model.stat.extent * 0.5
|
self.viewer.cam.distance = self.model.stat.extent * 0.5
|
||||||
self.viewer.cam.lookat[2] = 1.15
|
self.viewer.cam.lookat[2] = 1.15
|
||||||
|
@@ -153,6 +153,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, np.ndarray):
|
||||||
getattr(self.viewer.cam, key)[:] = value
|
getattr(self.viewer.cam, key)[:] = value
|
||||||
|
@@ -277,6 +277,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
for key, value in DEFAULT_CAMERA_CONFIG.items():
|
||||||
if isinstance(value, np.ndarray):
|
if isinstance(value, np.ndarray):
|
||||||
getattr(self.viewer.cam, key)[:] = value
|
getattr(self.viewer.cam, key)[:] = value
|
||||||
|
@@ -134,10 +134,12 @@ class Graph(Space):
|
|||||||
assert (
|
assert (
|
||||||
num_edges >= 0
|
num_edges >= 0
|
||||||
), f"The number of edges is expected to be greater than 0, actual mask: {num_edges}"
|
), f"The number of edges is expected to be greater than 0, actual mask: {num_edges}"
|
||||||
|
assert num_edges is not None
|
||||||
|
|
||||||
sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
|
sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
|
||||||
sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges)
|
sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges)
|
||||||
|
|
||||||
|
assert sampled_node_space is not None
|
||||||
sampled_nodes = sampled_node_space.sample(node_space_mask)
|
sampled_nodes = sampled_node_space.sample(node_space_mask)
|
||||||
sampled_edges = (
|
sampled_edges = (
|
||||||
sampled_edge_space.sample(edge_space_mask)
|
sampled_edge_space.sample(edge_space_mask)
|
||||||
|
@@ -3,16 +3,23 @@ from collections import deque
|
|||||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pygame
|
|
||||||
from pygame import Surface
|
|
||||||
from pygame.event import Event
|
|
||||||
from pygame.locals import VIDEORESIZE
|
|
||||||
|
|
||||||
|
import gym.error
|
||||||
from gym import Env, logger
|
from gym import Env, logger
|
||||||
from gym.core import ActType, ObsType
|
from gym.core import ActType, ObsType
|
||||||
from gym.error import DependencyNotInstalled
|
from gym.error import DependencyNotInstalled
|
||||||
from gym.logger import deprecation
|
from gym.logger import deprecation
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pygame
|
||||||
|
from pygame import Surface
|
||||||
|
from pygame.event import Event
|
||||||
|
from pygame.locals import VIDEORESIZE
|
||||||
|
except ImportError:
|
||||||
|
raise gym.error.DependencyNotInstalled(
|
||||||
|
"Pygame is not installed, run `pip install gym[classic_control]`"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
|
||||||
@@ -20,7 +27,7 @@ try:
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warn("Matplotlib is not installed, run `pip install gym[other]`")
|
logger.warn("Matplotlib is not installed, run `pip install gym[other]`")
|
||||||
matplotlib, plt = None, None
|
plt = None
|
||||||
|
|
||||||
|
|
||||||
class MissingKeysToAction(Exception):
|
class MissingKeysToAction(Exception):
|
||||||
@@ -33,7 +40,7 @@ class PlayableGame:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env: Env,
|
env: Env,
|
||||||
keys_to_action: Optional[Dict[Tuple[int], int]] = None,
|
keys_to_action: Optional[Dict[Tuple[int, ...], int]] = None,
|
||||||
zoom: Optional[float] = None,
|
zoom: Optional[float] = None,
|
||||||
):
|
):
|
||||||
"""Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.
|
"""Wraps an environment with a dictionary of keyboard buttons to action and if to zoom in on the environment.
|
||||||
@@ -63,12 +70,14 @@ class PlayableGame:
|
|||||||
f"{self.env.spec.id} does not have explicit key to action mapping, "
|
f"{self.env.spec.id} does not have explicit key to action mapping, "
|
||||||
"please specify one manually"
|
"please specify one manually"
|
||||||
)
|
)
|
||||||
|
assert isinstance(keys_to_action, dict)
|
||||||
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
|
relevant_keys = set(sum((list(k) for k in keys_to_action.keys()), []))
|
||||||
return relevant_keys
|
return relevant_keys
|
||||||
|
|
||||||
def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]:
|
def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]:
|
||||||
# TODO: this needs to be updated when the render API change goes through
|
# TODO: this needs to be updated when the render API change goes through
|
||||||
rendered = self.env.render(mode="rgb_array")
|
rendered = self.env.render(mode="rgb_array")
|
||||||
|
assert rendered is not None and isinstance(rendered, np.ndarray)
|
||||||
video_size = [rendered.shape[1], rendered.shape[0]]
|
video_size = [rendered.shape[1], rendered.shape[0]]
|
||||||
|
|
||||||
if zoom is not None:
|
if zoom is not None:
|
||||||
@@ -211,9 +220,9 @@ def play(
|
|||||||
f"{env.spec.id} does not have explicit key to action mapping, "
|
f"{env.spec.id} does not have explicit key to action mapping, "
|
||||||
"please specify one manually"
|
"please specify one manually"
|
||||||
)
|
)
|
||||||
|
assert keys_to_action is not None
|
||||||
|
|
||||||
key_code_to_action = {}
|
key_code_to_action = {}
|
||||||
|
|
||||||
for key_combination, action in keys_to_action.items():
|
for key_combination, action in keys_to_action.items():
|
||||||
key_code = tuple(
|
key_code = tuple(
|
||||||
sorted(ord(key) if isinstance(key, str) else key for key in key_combination)
|
sorted(ord(key) if isinstance(key, str) else key for key in key_combination)
|
||||||
@@ -225,7 +234,7 @@ def play(
|
|||||||
if fps is None:
|
if fps is None:
|
||||||
fps = env.metadata.get("render_fps", 30)
|
fps = env.metadata.get("render_fps", 30)
|
||||||
|
|
||||||
done = True
|
done, obs = True, None
|
||||||
clock = pygame.time.Clock()
|
clock = pygame.time.Clock()
|
||||||
|
|
||||||
while game.running:
|
while game.running:
|
||||||
@@ -316,7 +325,7 @@ class PlayPlot:
|
|||||||
for axis, name in zip(self.ax, plot_names):
|
for axis, name in zip(self.ax, plot_names):
|
||||||
axis.set_title(name)
|
axis.set_title(name)
|
||||||
self.t = 0
|
self.t = 0
|
||||||
self.cur_plot = [None for _ in range(num_plots)]
|
self.cur_plot: List[Optional[plt.Axes]] = [None for _ in range(num_plots)]
|
||||||
self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]
|
self.data = [deque(maxlen=horizon_timesteps) for _ in range(num_plots)]
|
||||||
|
|
||||||
def callback(
|
def callback(
|
||||||
@@ -352,4 +361,9 @@ class PlayPlot:
|
|||||||
range(xmin, xmax), list(self.data[i]), c="blue"
|
range(xmin, xmax), list(self.data[i]), c="blue"
|
||||||
)
|
)
|
||||||
self.ax[i].set_xlim(xmin, xmax)
|
self.ax[i].set_xlim(xmin, xmax)
|
||||||
|
|
||||||
|
if plt is None:
|
||||||
|
raise DependencyNotInstalled(
|
||||||
|
"matplotlib is not installed, run `pip install gym[other]`"
|
||||||
|
)
|
||||||
plt.pause(0.000001)
|
plt.pause(0.000001)
|
||||||
|
@@ -569,7 +569,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
num_errors = self.num_envs - sum(successes)
|
num_errors = self.num_envs - sum(successes)
|
||||||
assert num_errors > 0
|
assert num_errors > 0
|
||||||
for _ in range(num_errors):
|
for i in range(num_errors):
|
||||||
index, exctype, value = self.error_queue.get()
|
index, exctype, value = self.error_queue.get()
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
|
f"Received the following error from Worker-{index}: {exctype.__name__}: {value}"
|
||||||
@@ -578,6 +578,7 @@ class AsyncVectorEnv(VectorEnv):
|
|||||||
self.parent_pipes[index].close()
|
self.parent_pipes[index].close()
|
||||||
self.parent_pipes[index] = None
|
self.parent_pipes[index] = None
|
||||||
|
|
||||||
|
if i == num_errors - 1:
|
||||||
logger.error("Raising the last exception back to the main process.")
|
logger.error("Raising the last exception back to the main process.")
|
||||||
raise exctype(value)
|
raise exctype(value)
|
||||||
|
|
||||||
|
@@ -1,9 +1,10 @@
|
|||||||
"""A synchronous vector environment."""
|
"""A synchronous vector environment."""
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Iterator, List, Optional, Sequence, Union
|
from typing import Any, Callable, Iterator, List, Optional, Sequence, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from gym import Env
|
||||||
from gym.spaces import Space
|
from gym.spaces import Space
|
||||||
from gym.vector.utils import concatenate, create_empty_array, iterate
|
from gym.vector.utils import concatenate, create_empty_array, iterate
|
||||||
from gym.vector.vector_env import VectorEnv
|
from gym.vector.vector_env import VectorEnv
|
||||||
@@ -28,7 +29,7 @@ class SyncVectorEnv(VectorEnv):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
env_fns: Iterator[callable],
|
env_fns: Iterator[Callable[[], Env]],
|
||||||
observation_space: Space = None,
|
observation_space: Space = None,
|
||||||
action_space: Space = None,
|
action_space: Space = None,
|
||||||
copy: bool = True,
|
copy: bool = True,
|
||||||
|
@@ -6,7 +6,8 @@ from gym.vector.utils.shared_memory import (
|
|||||||
read_from_shared_memory,
|
read_from_shared_memory,
|
||||||
write_to_shared_memory,
|
write_to_shared_memory,
|
||||||
)
|
)
|
||||||
from gym.vector.utils.spaces import BaseGymSpaces, _BaseGymSpaces, batch_space, iterate
|
from gym.vector.utils.spaces import _BaseGymSpaces # pyright: reportPrivateUsage=false
|
||||||
|
from gym.vector.utils.spaces import BaseGymSpaces, batch_space, iterate
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CloudpickleWrapper",
|
"CloudpickleWrapper",
|
||||||
@@ -17,7 +18,6 @@ __all__ = [
|
|||||||
"read_from_shared_memory",
|
"read_from_shared_memory",
|
||||||
"write_to_shared_memory",
|
"write_to_shared_memory",
|
||||||
"BaseGymSpaces",
|
"BaseGymSpaces",
|
||||||
"_BaseGymSpaces",
|
|
||||||
"batch_space",
|
"batch_space",
|
||||||
"iterate",
|
"iterate",
|
||||||
]
|
]
|
||||||
|
@@ -85,6 +85,7 @@ class VectorEnv(gym.Env):
|
|||||||
Raises:
|
Raises:
|
||||||
NotImplementedError: VectorEnv does not implement function
|
NotImplementedError: VectorEnv does not implement function
|
||||||
"""
|
"""
|
||||||
|
raise NotImplementedError("VectorEnv does not implement function")
|
||||||
|
|
||||||
def reset(
|
def reset(
|
||||||
self,
|
self,
|
||||||
|
@@ -2,13 +2,14 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym.error import DependencyNotInstalled
|
|
||||||
from gym.spaces import Box
|
from gym.spaces import Box
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
except ImportError:
|
except ImportError:
|
||||||
cv2 = None
|
raise gym.error.DependencyNotInstalled(
|
||||||
|
"opencv-python package not installed, run `pip install gym[other]` to get dependencies for atari"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AtariPreprocessing(gym.Wrapper):
|
class AtariPreprocessing(gym.Wrapper):
|
||||||
@@ -60,10 +61,6 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
ValueError: Disable frame-skipping in the original env
|
ValueError: Disable frame-skipping in the original env
|
||||||
"""
|
"""
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
if cv2 is None:
|
|
||||||
raise DependencyNotInstalled(
|
|
||||||
"opencv-python package not installed, run `pip install gym[other]` to get dependencies for atari"
|
|
||||||
)
|
|
||||||
assert frame_skip > 0
|
assert frame_skip > 0
|
||||||
assert screen_size > 0
|
assert screen_size > 0
|
||||||
assert noop_max >= 0
|
assert noop_max >= 0
|
||||||
@@ -87,6 +84,7 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
self.scale_obs = scale_obs
|
self.scale_obs = scale_obs
|
||||||
|
|
||||||
# buffer of most recent two observations for max pooling
|
# buffer of most recent two observations for max pooling
|
||||||
|
assert isinstance(env.observation_space, Box)
|
||||||
if grayscale_obs:
|
if grayscale_obs:
|
||||||
self.obs_buffer = [
|
self.obs_buffer = [
|
||||||
np.empty(env.observation_space.shape[:2], dtype=np.uint8),
|
np.empty(env.observation_space.shape[:2], dtype=np.uint8),
|
||||||
@@ -114,7 +112,7 @@ class AtariPreprocessing(gym.Wrapper):
|
|||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""Applies the preprocessing for an :meth:`env.step`."""
|
"""Applies the preprocessing for an :meth:`env.step`."""
|
||||||
total_reward = 0.0
|
total_reward, done, info = 0.0, False, {}
|
||||||
|
|
||||||
for t in range(self.frame_skip):
|
for t in range(self.frame_skip):
|
||||||
_, reward, done, info = self.env.step(action)
|
_, reward, done, info = self.env.step(action)
|
||||||
|
@@ -32,8 +32,9 @@ class GrayScaleObservation(gym.ObservationWrapper):
|
|||||||
self.keep_dim = keep_dim
|
self.keep_dim = keep_dim
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
len(env.observation_space.shape) == 3
|
isinstance(self.observation_space, Box)
|
||||||
and env.observation_space.shape[-1] == 3
|
and len(self.observation_space.shape) == 3
|
||||||
|
and self.observation_space.shape[-1] == 3
|
||||||
)
|
)
|
||||||
|
|
||||||
obs_shape = self.observation_space.shape[:2]
|
obs_shape = self.observation_space.shape[:2]
|
||||||
|
@@ -88,13 +88,16 @@ class HumanRendering(gym.Wrapper):
|
|||||||
"pygame is not installed, run `pip install gym[box2d]`"
|
"pygame is not installed, run `pip install gym[box2d]`"
|
||||||
)
|
)
|
||||||
if self.env.render_mode == "rgb_array":
|
if self.env.render_mode == "rgb_array":
|
||||||
last_rgb_array = self.env.render(**kwargs)[-1]
|
last_rgb_array = self.env.render(**kwargs)
|
||||||
|
assert isinstance(last_rgb_array, list)
|
||||||
|
last_rgb_array = last_rgb_array[-1]
|
||||||
elif self.env.render_mode == "single_rgb_array":
|
elif self.env.render_mode == "single_rgb_array":
|
||||||
last_rgb_array = self.env.render(**kwargs)
|
last_rgb_array = self.env.render(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Wrapped environment must have mode 'rgb_array' or 'single_rgb_array'"
|
f"Wrapped environment must have mode 'rgb_array' or 'single_rgb_array', actual render mode: {self.env.render_mode}"
|
||||||
)
|
)
|
||||||
|
assert isinstance(last_rgb_array, np.ndarray)
|
||||||
|
|
||||||
if mode == "human":
|
if mode == "human":
|
||||||
rgb_array = np.transpose(last_rgb_array, axes=(1, 0, 2))
|
rgb_array = np.transpose(last_rgb_array, axes=(1, 0, 2))
|
||||||
|
@@ -148,7 +148,9 @@ class VideoRecorder: # TODO: remove with gym 1.0
|
|||||||
)
|
)
|
||||||
self.output_frames_per_sec = self.backward_compatible_output_frames_per_sec
|
self.output_frames_per_sec = self.backward_compatible_output_frames_per_sec
|
||||||
|
|
||||||
self.encoder = None # lazily start the process
|
self.encoder: Optional[
|
||||||
|
Union[TextEncoder, ImageEncoder]
|
||||||
|
] = None # lazily start the process
|
||||||
self.broken = False
|
self.broken = False
|
||||||
|
|
||||||
# Dump metadata
|
# Dump metadata
|
||||||
@@ -387,7 +389,7 @@ class ImageEncoder:
|
|||||||
InvalidFrame: Expects frame to have shape (w,h,3) or (w,h,4)
|
InvalidFrame: Expects frame to have shape (w,h,3) or (w,h,4)
|
||||||
DependencyNotInstalled: Found neither the ffmpeg nor avconv executables.
|
DependencyNotInstalled: Found neither the ffmpeg nor avconv executables.
|
||||||
"""
|
"""
|
||||||
self.proc = None
|
self.proc: Optional[subprocess.Popen] = None
|
||||||
self.output_path = output_path
|
self.output_path = output_path
|
||||||
# Frame shape should be lines-first, so w and h are swapped
|
# Frame shape should be lines-first, so w and h are swapped
|
||||||
h, w, pixfmt = frame_shape
|
h, w, pixfmt = frame_shape
|
||||||
@@ -488,6 +490,7 @@ class ImageEncoder:
|
|||||||
f"Your frame has data type {frame.dtype}, but we require uint8 (i.e. RGB values from 0-255)."
|
f"Your frame has data type {frame.dtype}, but we require uint8 (i.e. RGB values from 0-255)."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert self.proc is not None and self.proc.stdin is not None
|
||||||
try:
|
try:
|
||||||
self.proc.stdin.write(frame.tobytes())
|
self.proc.stdin.write(frame.tobytes())
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -496,6 +499,7 @@ class ImageEncoder:
|
|||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Closes the Image encoder."""
|
"""Closes the Image encoder."""
|
||||||
|
assert self.proc is not None and self.proc.stdin is not None
|
||||||
self.proc.stdin.close()
|
self.proc.stdin.close()
|
||||||
ret = self.proc.wait()
|
ret = self.proc.wait()
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
|
@@ -81,19 +81,20 @@ class NormalizeObservation(gym.core.Wrapper):
|
|||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
"""Resets the environment and normalizes the observation."""
|
"""Resets the environment and normalizes the observation."""
|
||||||
return_info = kwargs.get("return_info", False)
|
if kwargs.get("return_info", False):
|
||||||
if return_info:
|
|
||||||
obs, info = self.env.reset(**kwargs)
|
obs, info = self.env.reset(**kwargs)
|
||||||
|
|
||||||
|
if self.is_vector_env:
|
||||||
|
return self.normalize(obs), info
|
||||||
|
else:
|
||||||
|
return self.normalize(np.array([obs]))[0], info
|
||||||
else:
|
else:
|
||||||
obs = self.env.reset(**kwargs)
|
obs = self.env.reset(**kwargs)
|
||||||
|
|
||||||
if self.is_vector_env:
|
if self.is_vector_env:
|
||||||
obs = self.normalize(obs)
|
return self.normalize(obs)
|
||||||
else:
|
else:
|
||||||
obs = self.normalize(np.array([obs]))[0]
|
return self.normalize(np.array([obs]))[0]
|
||||||
if not return_info:
|
|
||||||
return obs
|
|
||||||
else:
|
|
||||||
return obs, info
|
|
||||||
|
|
||||||
def normalize(self, obs):
|
def normalize(self, obs):
|
||||||
"""Normalises the observation using the running mean and variance of the observations."""
|
"""Normalises the observation using the running mean and variance of the observations."""
|
||||||
|
@@ -49,3 +49,8 @@ class OrderEnforcing(gym.Wrapper):
|
|||||||
"set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
|
"set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
|
||||||
)
|
)
|
||||||
return self.env.render(*args, **kwargs)
|
return self.env.render(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_reset(self):
|
||||||
|
"""Returns if the environment has been reset before."""
|
||||||
|
return self._has_reset
|
||||||
|
@@ -120,8 +120,7 @@ class PixelObservationWrapper(gym.ObservationWrapper):
|
|||||||
elif self._observation_is_dict:
|
elif self._observation_is_dict:
|
||||||
self.observation_space = copy.deepcopy(wrapped_observation_space)
|
self.observation_space = copy.deepcopy(wrapped_observation_space)
|
||||||
else:
|
else:
|
||||||
self.observation_space = spaces.Dict()
|
self.observation_space = spaces.Dict({STATE_KEY: wrapped_observation_space})
|
||||||
self.observation_space.spaces[STATE_KEY] = wrapped_observation_space
|
|
||||||
|
|
||||||
# Extend observation space with pixels.
|
# Extend observation space with pixels.
|
||||||
|
|
||||||
@@ -129,7 +128,7 @@ class PixelObservationWrapper(gym.ObservationWrapper):
|
|||||||
pixels_spaces = {}
|
pixels_spaces = {}
|
||||||
for pixel_key in pixel_keys:
|
for pixel_key in pixel_keys:
|
||||||
pixels = self.env.render(**render_kwargs[pixel_key])
|
pixels = self.env.render(**render_kwargs[pixel_key])
|
||||||
pixels = pixels[-1] if isinstance(pixels, List) else pixels
|
pixels: np.ndarray = pixels[-1] if isinstance(pixels, List) else pixels
|
||||||
|
|
||||||
if np.issubdtype(pixels.dtype, np.integer):
|
if np.issubdtype(pixels.dtype, np.integer):
|
||||||
low, high = (0, 255)
|
low, high = (0, 255)
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
"""Wrapper that tracks the cumulative rewards and episode lengths."""
|
"""Wrapper that tracks the cumulative rewards and episode lengths."""
|
||||||
import time
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -86,8 +87,8 @@ class RecordEpisodeStatistics(gym.Wrapper):
|
|||||||
self.num_envs = getattr(env, "num_envs", 1)
|
self.num_envs = getattr(env, "num_envs", 1)
|
||||||
self.t0 = time.perf_counter()
|
self.t0 = time.perf_counter()
|
||||||
self.episode_count = 0
|
self.episode_count = 0
|
||||||
self.episode_returns = None
|
self.episode_returns: Optional[np.ndarray] = None
|
||||||
self.episode_lengths = None
|
self.episode_lengths: Optional[np.ndarray] = None
|
||||||
self.return_queue = deque(maxlen=deque_size)
|
self.return_queue = deque(maxlen=deque_size)
|
||||||
self.length_queue = deque(maxlen=deque_size)
|
self.length_queue = deque(maxlen=deque_size)
|
||||||
self.is_vector_env = getattr(env, "is_vector_env", False)
|
self.is_vector_env = getattr(env, "is_vector_env", False)
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
"""Wrapper for recording videos."""
|
"""Wrapper for recording videos."""
|
||||||
import os
|
import os
|
||||||
from typing import Callable
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
from gym import logger
|
from gym import logger
|
||||||
@@ -77,7 +77,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
|
|||||||
|
|
||||||
self.episode_trigger = episode_trigger
|
self.episode_trigger = episode_trigger
|
||||||
self.step_trigger = step_trigger
|
self.step_trigger = step_trigger
|
||||||
self.video_recorder = None
|
self.video_recorder: Optional[video_recorder.VideoRecorder] = None
|
||||||
|
|
||||||
self.video_folder = os.path.abspath(video_folder)
|
self.video_folder = os.path.abspath(video_folder)
|
||||||
# Create output folder if needed
|
# Create output folder if needed
|
||||||
@@ -101,6 +101,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
|
|||||||
"""Reset the environment using kwargs and then starts recording if video enabled."""
|
"""Reset the environment using kwargs and then starts recording if video enabled."""
|
||||||
observations = super().reset(**kwargs)
|
observations = super().reset(**kwargs)
|
||||||
if self.recording:
|
if self.recording:
|
||||||
|
assert self.video_recorder is not None
|
||||||
self.video_recorder.capture_frame()
|
self.video_recorder.capture_frame()
|
||||||
self.recorded_frames += 1
|
self.recorded_frames += 1
|
||||||
if self.video_length > 0:
|
if self.video_length > 0:
|
||||||
@@ -148,6 +149,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
|
|||||||
self.episode_id += 1
|
self.episode_id += 1
|
||||||
|
|
||||||
if self.recording:
|
if self.recording:
|
||||||
|
assert self.video_recorder is not None
|
||||||
self.video_recorder.capture_frame()
|
self.video_recorder.capture_frame()
|
||||||
self.recorded_frames += 1
|
self.recorded_frames += 1
|
||||||
if self.video_length > 0:
|
if self.video_length > 0:
|
||||||
@@ -168,6 +170,7 @@ class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
|
|||||||
def close_video_recorder(self):
|
def close_video_recorder(self):
|
||||||
"""Closes the video recorder if currently recording."""
|
"""Closes the video recorder if currently recording."""
|
||||||
if self.recording:
|
if self.recording:
|
||||||
|
assert self.video_recorder is not None
|
||||||
self.video_recorder.close()
|
self.video_recorder.close()
|
||||||
self.recording = False
|
self.recording = False
|
||||||
self.recorded_frames = 1
|
self.recorded_frames = 1
|
||||||
|
@@ -39,7 +39,10 @@ class ResizeObservation(gym.ObservationWrapper):
|
|||||||
|
|
||||||
self.shape = tuple(shape)
|
self.shape = tuple(shape)
|
||||||
|
|
||||||
obs_shape = self.shape + self.observation_space.shape[2:]
|
assert isinstance(
|
||||||
|
env.observation_space, Box
|
||||||
|
), f"Expected the observation space to be Box, actual type: {type(env.observation_space)}"
|
||||||
|
obs_shape = self.shape + env.observation_space.shape[2:]
|
||||||
self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
|
self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
|
||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
|
@@ -9,40 +9,7 @@ exclude = [
|
|||||||
"**/node_modules",
|
"**/node_modules",
|
||||||
"**/__pycache__",
|
"**/__pycache__",
|
||||||
|
|
||||||
"gym/envs/box2d/bipedal_walker.py",
|
"gym/envs/mujoco/mujoco_env.py",
|
||||||
"gym/envs/box2d/car_racing.py",
|
|
||||||
|
|
||||||
"gym/spaces/graph.py",
|
|
||||||
|
|
||||||
"gym/envs/mujoco/**",
|
|
||||||
"gym/utils/play.py",
|
|
||||||
|
|
||||||
"gym/vector/async_vector_env.py",
|
|
||||||
"gym/vector/utils/__init__.py",
|
|
||||||
|
|
||||||
"gym/wrappers/atari_preprocessing.py",
|
|
||||||
"gym/wrappers/gray_scale_observation.py",
|
|
||||||
"gym/wrappers/human_rendering.py",
|
|
||||||
"gym/wrappers/normalize.py",
|
|
||||||
"gym/wrappers/pixel_observation.py",
|
|
||||||
"gym/wrappers/record_video.py",
|
|
||||||
"gym/wrappers/monitoring/video_recorder.py",
|
|
||||||
"gym/wrappers/resize_observation.py",
|
|
||||||
|
|
||||||
"tests/envs/test_env_implementation.py",
|
|
||||||
"tests/utils/test_play.py",
|
|
||||||
"tests/vector/test_async_vector_env.py",
|
|
||||||
"tests/vector/test_shared_memory.py",
|
|
||||||
"tests/vector/test_spaces.py",
|
|
||||||
"tests/vector/test_sync_vector_env.py",
|
|
||||||
"tests/vector/test_vector_env.py",
|
|
||||||
"tests/wrappers/test_gray_scale_observation.py",
|
|
||||||
"tests/wrappers/test_order_enforcing.py",
|
|
||||||
"tests/wrappers/test_record_episode_statistics.py",
|
|
||||||
"tests/wrappers/test_resize_observation.py",
|
|
||||||
"tests/wrappers/test_time_aware_observation.py",
|
|
||||||
"tests/wrappers/test_video_recorder.py",
|
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|
||||||
strict = [
|
strict = [
|
||||||
@@ -51,6 +18,7 @@ strict = [
|
|||||||
|
|
||||||
typeCheckingMode = "basic"
|
typeCheckingMode = "basic"
|
||||||
pythonVersion = "3.6"
|
pythonVersion = "3.6"
|
||||||
|
pythonPlatform = "All"
|
||||||
typeshedPath = "typeshed"
|
typeshedPath = "typeshed"
|
||||||
enableTypeIgnoreComments = true
|
enableTypeIgnoreComments = true
|
||||||
|
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
"""Finds all the specs that we can test with"""
|
"""Finds all the specs that we can test with"""
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -18,22 +18,27 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
|
|||||||
|
|
||||||
|
|
||||||
# Tries to make all environment to test with
|
# Tries to make all environment to test with
|
||||||
all_testing_initialised_envs = list(
|
all_testing_initialised_envs: List[Optional[gym.Env]] = [
|
||||||
filter(None, [try_make_env(env_spec) for env_spec in gym.envs.registry.values()])
|
try_make_env(env_spec) for env_spec in gym.envs.registry.values()
|
||||||
)
|
]
|
||||||
|
all_testing_initialised_envs: List[gym.Env] = [
|
||||||
|
env for env in all_testing_initialised_envs if env is not None
|
||||||
|
]
|
||||||
|
|
||||||
# All testing, mujoco and gym environment specs
|
# All testing, mujoco and gym environment specs
|
||||||
all_testing_env_specs = [env.spec for env in all_testing_initialised_envs]
|
all_testing_env_specs: List[EnvSpec] = [
|
||||||
mujoco_testing_env_specs = [
|
env.spec for env in all_testing_initialised_envs
|
||||||
|
]
|
||||||
|
mujoco_testing_env_specs: List[EnvSpec] = [
|
||||||
env_spec
|
env_spec
|
||||||
for env_spec in all_testing_env_specs
|
for env_spec in all_testing_env_specs
|
||||||
if "gym.envs.mujoco" in env_spec.entry_point
|
if "gym.envs.mujoco" in env_spec.entry_point
|
||||||
]
|
]
|
||||||
gym_testing_env_specs = [
|
gym_testing_env_specs: List[EnvSpec] = [
|
||||||
env_spec
|
env_spec
|
||||||
for env_spec in all_testing_env_specs
|
for env_spec in all_testing_env_specs
|
||||||
if any(
|
if any(
|
||||||
f"gym.{ep}" in env_spec.entry_point
|
f"gym.envs.{ep}" in env_spec.entry_point
|
||||||
for ep in ["box2d", "classic_control", "toy_text"]
|
for ep in ["box2d", "classic_control", "toy_text"]
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
@@ -194,7 +194,7 @@ def test_play_loop_real_env():
|
|||||||
|
|
||||||
# first action is 0 because at the first iteration
|
# first action is 0 because at the first iteration
|
||||||
# we can not inject a callback event into play()
|
# we can not inject a callback event into play()
|
||||||
env.step(0)
|
obs, _, _, _ = env.step(0)
|
||||||
for e in keydown_events:
|
for e in keydown_events:
|
||||||
action = keys_to_action[chr(e.key) if str_keys else (e.key,)]
|
action = keys_to_action[chr(e.key) if str_keys else (e.key,)]
|
||||||
obs, _, _, _ = env.step(action)
|
obs, _, _, _ = env.step(action)
|
||||||
|
@@ -17,21 +17,19 @@ from tests.vector.utils import (
|
|||||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||||
def test_create_async_vector_env(shared_memory):
|
def test_create_async_vector_env(shared_memory):
|
||||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||||
try:
|
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
|
||||||
finally:
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
assert env.num_envs == 8
|
assert env.num_envs == 8
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||||
def test_reset_async_vector_env(shared_memory):
|
def test_reset_async_vector_env(shared_memory):
|
||||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||||
try:
|
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
observations = env.reset()
|
observations = env.reset()
|
||||||
finally:
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
assert isinstance(env.observation_space, Box)
|
assert isinstance(env.observation_space, Box)
|
||||||
@@ -71,7 +69,7 @@ def test_reset_async_vector_env(shared_memory):
|
|||||||
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
||||||
def test_step_async_vector_env(shared_memory, use_single_action_space):
|
def test_step_async_vector_env(shared_memory, use_single_action_space):
|
||||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||||
try:
|
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
observations = env.reset()
|
observations = env.reset()
|
||||||
|
|
||||||
@@ -83,7 +81,7 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
|
|||||||
else:
|
else:
|
||||||
actions = env.action_space.sample()
|
actions = env.action_space.sample()
|
||||||
observations, rewards, dones, _ = env.step(actions)
|
observations, rewards, dones, _ = env.step(actions)
|
||||||
finally:
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
assert isinstance(env.observation_space, Box)
|
assert isinstance(env.observation_space, Box)
|
||||||
@@ -106,12 +104,12 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
|
|||||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||||
def test_call_async_vector_env(shared_memory):
|
def test_call_async_vector_env(shared_memory):
|
||||||
env_fns = [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(4)]
|
env_fns = [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(4)]
|
||||||
try:
|
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
_ = env.reset()
|
_ = env.reset()
|
||||||
images = env.call("render")
|
images = env.call("render")
|
||||||
gravity = env.call("gravity")
|
gravity = env.call("gravity")
|
||||||
finally:
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
assert isinstance(images, tuple)
|
assert isinstance(images, tuple)
|
||||||
@@ -130,59 +128,60 @@ def test_call_async_vector_env(shared_memory):
|
|||||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||||
def test_set_attr_async_vector_env(shared_memory):
|
def test_set_attr_async_vector_env(shared_memory):
|
||||||
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
||||||
try:
|
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
|
env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
|
||||||
gravity = env.get_attr("gravity")
|
gravity = env.get_attr("gravity")
|
||||||
assert gravity == (9.81, 3.72, 8.87, 1.62)
|
assert gravity == (9.81, 3.72, 8.87, 1.62)
|
||||||
finally:
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||||
def test_copy_async_vector_env(shared_memory):
|
def test_copy_async_vector_env(shared_memory):
|
||||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||||
try:
|
|
||||||
|
# TODO, these tests do nothing, understand the purpose of the tests and fix them
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=True)
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=True)
|
||||||
observations = env.reset()
|
observations = env.reset()
|
||||||
observations[0] = 0
|
observations[0] = 0
|
||||||
finally:
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||||
def test_no_copy_async_vector_env(shared_memory):
|
def test_no_copy_async_vector_env(shared_memory):
|
||||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||||
try:
|
|
||||||
|
# TODO, these tests do nothing, understand the purpose of the tests and fix them
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=False)
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory, copy=False)
|
||||||
observations = env.reset()
|
observations = env.reset()
|
||||||
observations[0] = 0
|
observations[0] = 0
|
||||||
finally:
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||||
def test_reset_timeout_async_vector_env(shared_memory):
|
def test_reset_timeout_async_vector_env(shared_memory):
|
||||||
env_fns = [make_slow_env(0.3, i) for i in range(4)]
|
env_fns = [make_slow_env(0.3, i) for i in range(4)]
|
||||||
with pytest.raises(TimeoutError):
|
|
||||||
try:
|
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
|
with pytest.raises(TimeoutError):
|
||||||
env.reset_async()
|
env.reset_async()
|
||||||
env.reset_wait(timeout=0.1)
|
env.reset_wait(timeout=0.1)
|
||||||
finally:
|
|
||||||
env.close(terminate=True)
|
env.close(terminate=True)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||||
def test_step_timeout_async_vector_env(shared_memory):
|
def test_step_timeout_async_vector_env(shared_memory):
|
||||||
env_fns = [make_slow_env(0.0, i) for i in range(4)]
|
env_fns = [make_slow_env(0.0, i) for i in range(4)]
|
||||||
with pytest.raises(TimeoutError):
|
|
||||||
try:
|
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
|
with pytest.raises(TimeoutError):
|
||||||
env.reset()
|
env.reset()
|
||||||
env.step_async([0.1, 0.1, 0.3, 0.1])
|
env.step_async(np.array([0.1, 0.1, 0.3, 0.1]))
|
||||||
observations, rewards, dones, _ = env.step_wait(timeout=0.1)
|
observations, rewards, dones, _ = env.step_wait(timeout=0.1)
|
||||||
finally:
|
|
||||||
env.close(terminate=True)
|
env.close(terminate=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -190,19 +189,20 @@ def test_step_timeout_async_vector_env(shared_memory):
|
|||||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||||
def test_reset_out_of_order_async_vector_env(shared_memory):
|
def test_reset_out_of_order_async_vector_env(shared_memory):
|
||||||
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
||||||
|
|
||||||
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
with pytest.raises(NoAsyncCallError):
|
with pytest.raises(NoAsyncCallError):
|
||||||
try:
|
try:
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
|
||||||
env.reset_wait()
|
env.reset_wait()
|
||||||
except NoAsyncCallError as exception:
|
except NoAsyncCallError as exception:
|
||||||
assert exception.name == "reset"
|
assert exception.name == "reset"
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
env.close(terminate=True)
|
env.close(terminate=True)
|
||||||
|
|
||||||
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
with pytest.raises(AlreadyPendingCallError):
|
with pytest.raises(AlreadyPendingCallError):
|
||||||
try:
|
try:
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
|
||||||
actions = env.action_space.sample()
|
actions = env.action_space.sample()
|
||||||
env.reset()
|
env.reset()
|
||||||
env.step_async(actions)
|
env.step_async(actions)
|
||||||
@@ -210,7 +210,7 @@ def test_reset_out_of_order_async_vector_env(shared_memory):
|
|||||||
except NoAsyncCallError as exception:
|
except NoAsyncCallError as exception:
|
||||||
assert exception.name == "step"
|
assert exception.name == "step"
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
env.close(terminate=True)
|
env.close(terminate=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -218,28 +218,29 @@ def test_reset_out_of_order_async_vector_env(shared_memory):
|
|||||||
@pytest.mark.parametrize("shared_memory", [True, False])
|
@pytest.mark.parametrize("shared_memory", [True, False])
|
||||||
def test_step_out_of_order_async_vector_env(shared_memory):
|
def test_step_out_of_order_async_vector_env(shared_memory):
|
||||||
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
||||||
|
|
||||||
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
with pytest.raises(NoAsyncCallError):
|
with pytest.raises(NoAsyncCallError):
|
||||||
try:
|
try:
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
env.action_space.sample()
|
||||||
actions = env.action_space.sample()
|
env.reset()
|
||||||
observations = env.reset()
|
env.step_wait()
|
||||||
observations, rewards, dones, infos = env.step_wait()
|
|
||||||
except AlreadyPendingCallError as exception:
|
except AlreadyPendingCallError as exception:
|
||||||
assert exception.name == "step"
|
assert exception.name == "step"
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
env.close(terminate=True)
|
env.close(terminate=True)
|
||||||
|
|
||||||
|
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
with pytest.raises(AlreadyPendingCallError):
|
with pytest.raises(AlreadyPendingCallError):
|
||||||
try:
|
try:
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
|
||||||
actions = env.action_space.sample()
|
actions = env.action_space.sample()
|
||||||
env.reset_async()
|
env.reset_async()
|
||||||
env.step_async(actions)
|
env.step_async(actions)
|
||||||
except AlreadyPendingCallError as exception:
|
except AlreadyPendingCallError as exception:
|
||||||
assert exception.name == "reset"
|
assert exception.name == "reset"
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
env.close(terminate=True)
|
env.close(terminate=True)
|
||||||
|
|
||||||
|
|
||||||
@@ -265,7 +266,7 @@ def test_check_spaces_async_vector_env(shared_memory):
|
|||||||
|
|
||||||
def test_custom_space_async_vector_env():
|
def test_custom_space_async_vector_env():
|
||||||
env_fns = [make_custom_space_env(i) for i in range(4)]
|
env_fns = [make_custom_space_env(i) for i in range(4)]
|
||||||
try:
|
|
||||||
env = AsyncVectorEnv(env_fns, shared_memory=False)
|
env = AsyncVectorEnv(env_fns, shared_memory=False)
|
||||||
reset_observations = env.reset()
|
reset_observations = env.reset()
|
||||||
|
|
||||||
@@ -274,7 +275,7 @@ def test_custom_space_async_vector_env():
|
|||||||
|
|
||||||
actions = ("action-2", "action-3", "action-5", "action-7")
|
actions = ("action-2", "action-3", "action-5", "action-7")
|
||||||
step_observations, rewards, dones, _ = env.step(actions)
|
step_observations, rewards, dones, _ = env.step(actions)
|
||||||
finally:
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
assert isinstance(env.single_observation_space, CustomSpace)
|
assert isinstance(env.single_observation_space, CustomSpace)
|
||||||
|
@@ -116,6 +116,7 @@ def test_batch_space_custom_space(space, expected_batch_space_4):
|
|||||||
def test_iterate(space, batch_space):
|
def test_iterate(space, batch_space):
|
||||||
items = batch_space.sample()
|
items = batch_space.sample()
|
||||||
iterator = iterate(batch_space, items)
|
iterator = iterate(batch_space, items)
|
||||||
|
i = 0
|
||||||
for i, item in enumerate(iterator):
|
for i, item in enumerate(iterator):
|
||||||
assert item in space
|
assert item in space
|
||||||
assert i == 3
|
assert i == 3
|
||||||
@@ -129,6 +130,7 @@ def test_iterate(space, batch_space):
|
|||||||
def test_iterate_custom_space(space, batch_space):
|
def test_iterate_custom_space(space, batch_space):
|
||||||
items = batch_space.sample()
|
items = batch_space.sample()
|
||||||
iterator = iterate(batch_space, items)
|
iterator = iterate(batch_space, items)
|
||||||
|
i = 0
|
||||||
for i, item in enumerate(iterator):
|
for i, item in enumerate(iterator):
|
||||||
assert item in space
|
assert item in space
|
||||||
assert i == 3
|
assert i == 3
|
||||||
|
@@ -15,9 +15,7 @@ from tests.vector.utils import (
|
|||||||
|
|
||||||
def test_create_sync_vector_env():
|
def test_create_sync_vector_env():
|
||||||
env_fns = [make_env("FrozenLake-v1", i) for i in range(8)]
|
env_fns = [make_env("FrozenLake-v1", i) for i in range(8)]
|
||||||
try:
|
|
||||||
env = SyncVectorEnv(env_fns)
|
env = SyncVectorEnv(env_fns)
|
||||||
finally:
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
assert env.num_envs == 8
|
assert env.num_envs == 8
|
||||||
@@ -25,10 +23,8 @@ def test_create_sync_vector_env():
|
|||||||
|
|
||||||
def test_reset_sync_vector_env():
|
def test_reset_sync_vector_env():
|
||||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||||
try:
|
|
||||||
env = SyncVectorEnv(env_fns)
|
env = SyncVectorEnv(env_fns)
|
||||||
observations = env.reset()
|
observations = env.reset()
|
||||||
finally:
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
assert isinstance(env.observation_space, Box)
|
assert isinstance(env.observation_space, Box)
|
||||||
@@ -39,10 +35,8 @@ def test_reset_sync_vector_env():
|
|||||||
|
|
||||||
del observations
|
del observations
|
||||||
|
|
||||||
try:
|
|
||||||
env = SyncVectorEnv(env_fns)
|
env = SyncVectorEnv(env_fns)
|
||||||
observations = env.reset(return_info=False)
|
observations = env.reset(return_info=False)
|
||||||
finally:
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
assert isinstance(env.observation_space, Box)
|
assert isinstance(env.observation_space, Box)
|
||||||
@@ -54,10 +48,9 @@ def test_reset_sync_vector_env():
|
|||||||
del observations
|
del observations
|
||||||
|
|
||||||
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(8)]
|
||||||
try:
|
|
||||||
env = SyncVectorEnv(env_fns)
|
env = SyncVectorEnv(env_fns)
|
||||||
observations, infos = env.reset(return_info=True)
|
observations, infos = env.reset(return_info=True)
|
||||||
finally:
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
assert isinstance(env.observation_space, Box)
|
assert isinstance(env.observation_space, Box)
|
||||||
@@ -72,7 +65,7 @@ def test_reset_sync_vector_env():
|
|||||||
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
@pytest.mark.parametrize("use_single_action_space", [True, False])
|
||||||
def test_step_sync_vector_env(use_single_action_space):
|
def test_step_sync_vector_env(use_single_action_space):
|
||||||
env_fns = [make_env("FrozenLake-v1", i) for i in range(8)]
|
env_fns = [make_env("FrozenLake-v1", i) for i in range(8)]
|
||||||
try:
|
|
||||||
env = SyncVectorEnv(env_fns)
|
env = SyncVectorEnv(env_fns)
|
||||||
observations = env.reset()
|
observations = env.reset()
|
||||||
|
|
||||||
@@ -84,7 +77,7 @@ def test_step_sync_vector_env(use_single_action_space):
|
|||||||
else:
|
else:
|
||||||
actions = env.action_space.sample()
|
actions = env.action_space.sample()
|
||||||
observations, rewards, dones, _ = env.step(actions)
|
observations, rewards, dones, _ = env.step(actions)
|
||||||
finally:
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
assert isinstance(env.observation_space, MultiDiscrete)
|
assert isinstance(env.observation_space, MultiDiscrete)
|
||||||
@@ -106,12 +99,12 @@ def test_step_sync_vector_env(use_single_action_space):
|
|||||||
|
|
||||||
def test_call_sync_vector_env():
|
def test_call_sync_vector_env():
|
||||||
env_fns = [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(4)]
|
env_fns = [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(4)]
|
||||||
try:
|
|
||||||
env = SyncVectorEnv(env_fns)
|
env = SyncVectorEnv(env_fns)
|
||||||
_ = env.reset()
|
_ = env.reset()
|
||||||
images = env.call("render")
|
images = env.call("render")
|
||||||
gravity = env.call("gravity")
|
gravity = env.call("gravity")
|
||||||
finally:
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
assert isinstance(images, tuple)
|
assert isinstance(images, tuple)
|
||||||
@@ -129,12 +122,12 @@ def test_call_sync_vector_env():
|
|||||||
|
|
||||||
def test_set_attr_sync_vector_env():
|
def test_set_attr_sync_vector_env():
|
||||||
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
||||||
try:
|
|
||||||
env = SyncVectorEnv(env_fns)
|
env = SyncVectorEnv(env_fns)
|
||||||
env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
|
env.set_attr("gravity", [9.81, 3.72, 8.87, 1.62])
|
||||||
gravity = env.get_attr("gravity")
|
gravity = env.get_attr("gravity")
|
||||||
assert gravity == (9.81, 3.72, 8.87, 1.62)
|
assert gravity == (9.81, 3.72, 8.87, 1.62)
|
||||||
finally:
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
@@ -150,7 +143,7 @@ def test_check_spaces_sync_vector_env():
|
|||||||
|
|
||||||
def test_custom_space_sync_vector_env():
|
def test_custom_space_sync_vector_env():
|
||||||
env_fns = [make_custom_space_env(i) for i in range(4)]
|
env_fns = [make_custom_space_env(i) for i in range(4)]
|
||||||
try:
|
|
||||||
env = SyncVectorEnv(env_fns)
|
env = SyncVectorEnv(env_fns)
|
||||||
reset_observations = env.reset()
|
reset_observations = env.reset()
|
||||||
|
|
||||||
@@ -159,7 +152,7 @@ def test_custom_space_sync_vector_env():
|
|||||||
|
|
||||||
actions = ("action-2", "action-3", "action-5", "action-7")
|
actions = ("action-2", "action-3", "action-5", "action-7")
|
||||||
step_observations, rewards, dones, _ = env.step(actions)
|
step_observations, rewards, dones, _ = env.step(actions)
|
||||||
finally:
|
|
||||||
env.close()
|
env.close()
|
||||||
|
|
||||||
assert isinstance(env.single_observation_space, CustomSpace)
|
assert isinstance(env.single_observation_space, CustomSpace)
|
||||||
|
@@ -12,7 +12,7 @@ from tests.vector.utils import CustomSpace, make_env
|
|||||||
def test_vector_env_equal(shared_memory):
|
def test_vector_env_equal(shared_memory):
|
||||||
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
env_fns = [make_env("CartPole-v1", i) for i in range(4)]
|
||||||
num_steps = 100
|
num_steps = 100
|
||||||
try:
|
|
||||||
async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
async_env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
|
||||||
sync_env = SyncVectorEnv(env_fns)
|
sync_env = SyncVectorEnv(env_fns)
|
||||||
|
|
||||||
@@ -45,7 +45,6 @@ def test_vector_env_equal(shared_memory):
|
|||||||
assert np.all(async_rewards == sync_rewards)
|
assert np.all(async_rewards == sync_rewards)
|
||||||
assert np.all(async_dones == sync_dones)
|
assert np.all(async_dones == sync_dones)
|
||||||
|
|
||||||
finally:
|
|
||||||
async_env.close()
|
async_env.close()
|
||||||
sync_env.close()
|
sync_env.close()
|
||||||
|
|
||||||
|
@@ -2,6 +2,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
|
from gym import spaces
|
||||||
from gym.wrappers import AtariPreprocessing, GrayScaleObservation
|
from gym.wrappers import AtariPreprocessing, GrayScaleObservation
|
||||||
|
|
||||||
pytest.importorskip("gym.envs.atari")
|
pytest.importorskip("gym.envs.atari")
|
||||||
@@ -20,8 +21,12 @@ def test_gray_scale_observation(env_id, keep_dim):
|
|||||||
gym.make(env_id, disable_env_checker=True), screen_size=84, grayscale_obs=False
|
gym.make(env_id, disable_env_checker=True), screen_size=84, grayscale_obs=False
|
||||||
)
|
)
|
||||||
wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim)
|
wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim)
|
||||||
|
|
||||||
|
assert isinstance(rgb_env.observation_space, spaces.Box)
|
||||||
assert rgb_env.observation_space.shape[-1] == 3
|
assert rgb_env.observation_space.shape[-1] == 3
|
||||||
|
|
||||||
|
assert isinstance(wrapped_env.observation_space, spaces.Box)
|
||||||
|
|
||||||
seed = 0
|
seed = 0
|
||||||
|
|
||||||
gray_obs = gray_env.reset(seed=seed)
|
gray_obs = gray_env.reset(seed=seed)
|
||||||
|
@@ -26,16 +26,16 @@ def test_order_enforcing():
|
|||||||
|
|
||||||
# Assert that the order enforcing works for step and render before reset
|
# Assert that the order enforcing works for step and render before reset
|
||||||
order_enforced_env = OrderEnforcing(env)
|
order_enforced_env = OrderEnforcing(env)
|
||||||
assert order_enforced_env._has_reset is False
|
assert order_enforced_env.has_reset is False
|
||||||
with pytest.raises(ResetNeeded):
|
with pytest.raises(ResetNeeded):
|
||||||
order_enforced_env.step(0)
|
order_enforced_env.step(0)
|
||||||
with pytest.raises(ResetNeeded):
|
with pytest.raises(ResetNeeded):
|
||||||
order_enforced_env.render(mode="rgb_array")
|
order_enforced_env.render(mode="rgb_array")
|
||||||
assert order_enforced_env._has_reset is False
|
assert order_enforced_env.has_reset is False
|
||||||
|
|
||||||
# Assert that the Assertion errors are not raised after reset
|
# Assert that the Assertion errors are not raised after reset
|
||||||
order_enforced_env.reset()
|
order_enforced_env.reset()
|
||||||
assert order_enforced_env._has_reset is True
|
assert order_enforced_env.has_reset is True
|
||||||
order_enforced_env.step(0)
|
order_enforced_env.step(0)
|
||||||
order_enforced_env.render(mode="rgb_array")
|
order_enforced_env.render(mode="rgb_array")
|
||||||
|
|
||||||
|
@@ -14,6 +14,7 @@ def test_record_episode_statistics(env_id, deque_size):
|
|||||||
|
|
||||||
for n in range(5):
|
for n in range(5):
|
||||||
env.reset()
|
env.reset()
|
||||||
|
assert env.episode_returns is not None and env.episode_lengths is not None
|
||||||
assert env.episode_returns[0] == 0.0
|
assert env.episode_returns[0] == 0.0
|
||||||
assert env.episode_lengths[0] == 0
|
assert env.episode_lengths[0] == 0
|
||||||
for t in range(env.spec.max_episode_steps):
|
for t in range(env.spec.max_episode_steps):
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
|
from gym import spaces
|
||||||
from gym.wrappers import ResizeObservation
|
from gym.wrappers import ResizeObservation
|
||||||
|
|
||||||
pytest.importorskip("gym.envs.atari")
|
pytest.importorskip("gym.envs.atari")
|
||||||
@@ -14,6 +15,7 @@ def test_resize_observation(env_id, shape):
|
|||||||
env = gym.make(env_id, disable_env_checker=True)
|
env = gym.make(env_id, disable_env_checker=True)
|
||||||
env = ResizeObservation(env, shape)
|
env = ResizeObservation(env, shape)
|
||||||
|
|
||||||
|
assert isinstance(env.observation_space, spaces.Box)
|
||||||
assert env.observation_space.shape[-1] == 3
|
assert env.observation_space.shape[-1] == 3
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
if isinstance(shape, int):
|
if isinstance(shape, int):
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
|
from gym import spaces
|
||||||
from gym.wrappers import TimeAwareObservation
|
from gym.wrappers import TimeAwareObservation
|
||||||
|
|
||||||
|
|
||||||
@@ -9,6 +10,8 @@ def test_time_aware_observation(env_id):
|
|||||||
env = gym.make(env_id, disable_env_checker=True)
|
env = gym.make(env_id, disable_env_checker=True)
|
||||||
wrapped_env = TimeAwareObservation(env)
|
wrapped_env = TimeAwareObservation(env)
|
||||||
|
|
||||||
|
assert isinstance(env.observation_space, spaces.Box)
|
||||||
|
assert isinstance(wrapped_env.observation_space, spaces.Box)
|
||||||
assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1
|
assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1
|
||||||
|
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
@@ -33,9 +33,10 @@ def test_record_simple():
|
|||||||
rec = VideoRecorder(env)
|
rec = VideoRecorder(env)
|
||||||
env.reset()
|
env.reset()
|
||||||
rec.capture_frame()
|
rec.capture_frame()
|
||||||
|
assert rec.encoder is not None
|
||||||
proc = rec.encoder.proc
|
proc = rec.encoder.proc
|
||||||
|
|
||||||
assert proc.poll() is None # subprocess is running
|
assert proc is not None and proc.poll() is None # subprocess is running
|
||||||
|
|
||||||
rec.close()
|
rec.close()
|
||||||
|
|
||||||
@@ -55,9 +56,10 @@ def test_autoclose():
|
|||||||
rec.capture_frame()
|
rec.capture_frame()
|
||||||
|
|
||||||
rec_path = rec.path
|
rec_path = rec.path
|
||||||
|
assert rec.encoder is not None
|
||||||
proc = rec.encoder.proc
|
proc = rec.encoder.proc
|
||||||
|
|
||||||
assert proc.poll() is None # subprocess is running
|
assert proc is not None and proc.poll() is None # subprocess is running
|
||||||
|
|
||||||
# The function ends without an explicit `rec.close()` call
|
# The function ends without an explicit `rec.close()` call
|
||||||
# The Python interpreter will implicitly do `del rec` on garbage cleaning
|
# The Python interpreter will implicitly do `del rec` on garbage cleaning
|
||||||
@@ -68,7 +70,7 @@ def test_autoclose():
|
|||||||
gc.collect() # do explicit garbage collection for test
|
gc.collect() # do explicit garbage collection for test
|
||||||
time.sleep(5) # wait for subprocess exiting
|
time.sleep(5) # wait for subprocess exiting
|
||||||
|
|
||||||
assert proc.poll() is not None # subprocess is terminated
|
assert proc is not None and proc.poll() is not None # subprocess is terminated
|
||||||
assert os.path.exists(rec_path)
|
assert os.path.exists(rec_path)
|
||||||
f = open(rec_path)
|
f = open(rec_path)
|
||||||
assert os.fstat(f.fileno()).st_size > 100
|
assert os.fstat(f.fileno()).st_size > 100
|
||||||
|
Reference in New Issue
Block a user