mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 22:04:31 +00:00
Remove old Render API (#3027)
* init * add .gitignore * fix .gitignore * remove internal backend use * fix VideoRecorder test * fix .gitignore * fix order enforcing tests * adapt play.py * reformat * fix .gitignore * add type to DummyPlayEnv
This commit is contained in:
43
gym/core.py
43
gym/core.py
@@ -32,35 +32,6 @@ ActType = TypeVar("ActType")
|
|||||||
RenderFrame = TypeVar("RenderFrame")
|
RenderFrame = TypeVar("RenderFrame")
|
||||||
|
|
||||||
|
|
||||||
# TODO: remove with gym 1.0
|
|
||||||
def _deprecate_mode(render_func): # type: ignore
|
|
||||||
"""Wrapper used for adding deprecation warning to the mode kwarg in the render method."""
|
|
||||||
render_return = Optional[Union[RenderFrame, List[RenderFrame]]]
|
|
||||||
|
|
||||||
def render(
|
|
||||||
self: object, *args: Tuple[Any], **kwargs: Dict[str, Any]
|
|
||||||
) -> render_return:
|
|
||||||
if "mode" in kwargs.keys() or len(args) > 0:
|
|
||||||
deprecation(
|
|
||||||
"The argument mode in render method is deprecated; "
|
|
||||||
"use render_mode during environment initialization instead.\n"
|
|
||||||
"See here for more information: https://www.gymlibrary.ml/content/api/"
|
|
||||||
)
|
|
||||||
elif self.spec is not None and "render_mode" not in self.spec.kwargs.keys(): # type: ignore
|
|
||||||
deprecation(
|
|
||||||
"You are calling render method, "
|
|
||||||
"but you didn't specified the argument render_mode at environment initialization. "
|
|
||||||
"To maintain backward compatibility, the environment will render in human mode.\n"
|
|
||||||
"If you want to render in human mode, initialize the environment in this way: "
|
|
||||||
"gym.make('EnvName', render_mode='human') and don't call the render method.\n"
|
|
||||||
"See here for more information: https://www.gymlibrary.ml/content/api/"
|
|
||||||
)
|
|
||||||
|
|
||||||
return render_func(self, *args, **kwargs)
|
|
||||||
|
|
||||||
return render
|
|
||||||
|
|
||||||
|
|
||||||
class Env(Generic[ObsType, ActType]):
|
class Env(Generic[ObsType, ActType]):
|
||||||
r"""The main OpenAI Gym class.
|
r"""The main OpenAI Gym class.
|
||||||
|
|
||||||
@@ -88,12 +59,6 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
Note: a default reward range set to :math:`(-\infty,+\infty)` already exists. Set it if you want a narrower range.
|
Note: a default reward range set to :math:`(-\infty,+\infty)` already exists. Set it if you want a narrower range.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init_subclass__(cls) -> None:
|
|
||||||
"""Hook used for wrapping render method."""
|
|
||||||
super().__init_subclass__()
|
|
||||||
if "render" in vars(cls):
|
|
||||||
cls.render = _deprecate_mode(vars(cls)["render"])
|
|
||||||
|
|
||||||
# Set this in SOME subclasses
|
# Set this in SOME subclasses
|
||||||
metadata: Dict[str, Any] = {"render_modes": []}
|
metadata: Dict[str, Any] = {"render_modes": []}
|
||||||
# define render_mode if your environment supports rendering
|
# define render_mode if your environment supports rendering
|
||||||
@@ -195,8 +160,7 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
if seed is not None:
|
if seed is not None:
|
||||||
self._np_random, seed = seeding.np_random(seed)
|
self._np_random, seed = seeding.np_random(seed)
|
||||||
|
|
||||||
# TODO: remove kwarg mode with gym 1.0
|
def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
|
||||||
def render(self, mode="human") -> Optional[Union[RenderFrame, List[RenderFrame]]]:
|
|
||||||
"""Compute the render frames as specified by render_mode attribute during initialization of the environment.
|
"""Compute the render frames as specified by render_mode attribute during initialization of the environment.
|
||||||
|
|
||||||
The set of supported modes varies per environment. (And some
|
The set of supported modes varies per environment. (And some
|
||||||
@@ -214,11 +178,6 @@ class Env(Generic[ObsType, ActType]):
|
|||||||
terminal-style text representation for each time step.
|
terminal-style text representation for each time step.
|
||||||
The text can include newlines and ANSI escape sequences (e.g. for colors).
|
The text can include newlines and ANSI escape sequences (e.g. for colors).
|
||||||
|
|
||||||
Note:
|
|
||||||
Rendering computations is performed internally even if you don't call render().
|
|
||||||
To avoid this, you can set render_mode = None and, if the environment supports it,
|
|
||||||
call render() specifying the argument 'mode'.
|
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
Make sure that your class's metadata 'render_modes' key includes
|
Make sure that your class's metadata 'render_modes' key includes
|
||||||
the list of supported modes. It's recommended to call super()
|
the list of supported modes. It's recommended to call super()
|
||||||
|
@@ -608,11 +608,8 @@ class BipedalWalker(gym.Env, EzPickle):
|
|||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
return np.array(state, dtype=np.float32), reward, terminated, False, {}
|
return np.array(state, dtype=np.float32), reward, terminated, False, {}
|
||||||
|
|
||||||
def render(self, mode: str = "human"):
|
def render(self):
|
||||||
if self.render_mode is not None:
|
return self.renderer.get_renders()
|
||||||
return self.renderer.get_renders()
|
|
||||||
else:
|
|
||||||
return self._render(mode)
|
|
||||||
|
|
||||||
def _render(self, mode: str = "human"):
|
def _render(self, mode: str = "human"):
|
||||||
assert mode in self.metadata["render_modes"]
|
assert mode in self.metadata["render_modes"]
|
||||||
|
@@ -570,11 +570,8 @@ class CarRacing(gym.Env, EzPickle):
|
|||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
return self.state, step_reward, terminated, truncated, {}
|
return self.state, step_reward, terminated, truncated, {}
|
||||||
|
|
||||||
def render(self, mode: str = "human"):
|
def render(self):
|
||||||
if self.render_mode is not None:
|
return self.renderer.get_renders()
|
||||||
return self.renderer.get_renders()
|
|
||||||
else:
|
|
||||||
return self._render(mode)
|
|
||||||
|
|
||||||
def _render(self, mode: str = "human"):
|
def _render(self, mode: str = "human"):
|
||||||
assert mode in self.metadata["render_modes"]
|
assert mode in self.metadata["render_modes"]
|
||||||
|
@@ -596,11 +596,8 @@ class LunarLander(gym.Env, EzPickle):
|
|||||||
self.renderer.render_step()
|
self.renderer.render_step()
|
||||||
return np.array(state, dtype=np.float32), reward, terminated, False, {}
|
return np.array(state, dtype=np.float32), reward, terminated, False, {}
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self):
|
||||||
if self.render_mode is not None:
|
return self.renderer.get_renders()
|
||||||
return self.renderer.get_renders()
|
|
||||||
else:
|
|
||||||
return self._render(mode)
|
|
||||||
|
|
||||||
def _render(self, mode="human"):
|
def _render(self, mode="human"):
|
||||||
assert mode in self.metadata["render_modes"]
|
assert mode in self.metadata["render_modes"]
|
||||||
|
@@ -286,11 +286,8 @@ class AcrobotEnv(core.Env):
|
|||||||
ddtheta1 = -(d2 * ddtheta2 + phi1) / d1
|
ddtheta1 = -(d2 * ddtheta2 + phi1) / d1
|
||||||
return dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0
|
return dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self):
|
||||||
if self.render_mode is not None:
|
return self.renderer.get_renders()
|
||||||
return self.renderer.get_renders()
|
|
||||||
else:
|
|
||||||
return self._render(mode)
|
|
||||||
|
|
||||||
def _render(self, mode="human"):
|
def _render(self, mode="human"):
|
||||||
assert mode in self.metadata["render_modes"]
|
assert mode in self.metadata["render_modes"]
|
||||||
|
@@ -210,11 +210,8 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
|||||||
else:
|
else:
|
||||||
return np.array(self.state, dtype=np.float32), {}
|
return np.array(self.state, dtype=np.float32), {}
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self):
|
||||||
if self.render_mode is not None:
|
return self.renderer.get_renders()
|
||||||
return self.renderer.get_renders()
|
|
||||||
else:
|
|
||||||
return self._render(mode)
|
|
||||||
|
|
||||||
def _render(self, mode="human"):
|
def _render(self, mode="human"):
|
||||||
assert mode in self.metadata["render_modes"]
|
assert mode in self.metadata["render_modes"]
|
||||||
|
@@ -196,11 +196,8 @@ class Continuous_MountainCarEnv(gym.Env):
|
|||||||
def _height(self, xs):
|
def _height(self, xs):
|
||||||
return np.sin(3 * xs) * 0.45 + 0.55
|
return np.sin(3 * xs) * 0.45 + 0.55
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self):
|
||||||
if self.render_mode is not None:
|
return self.renderer.get_renders()
|
||||||
return self.renderer.get_renders()
|
|
||||||
else:
|
|
||||||
return self._render(mode)
|
|
||||||
|
|
||||||
def _render(self, mode="human"):
|
def _render(self, mode="human"):
|
||||||
assert mode in self.metadata["render_modes"]
|
assert mode in self.metadata["render_modes"]
|
||||||
|
@@ -170,11 +170,8 @@ class MountainCarEnv(gym.Env):
|
|||||||
def _height(self, xs):
|
def _height(self, xs):
|
||||||
return np.sin(3 * xs) * 0.45 + 0.55
|
return np.sin(3 * xs) * 0.45 + 0.55
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self):
|
||||||
if self.render_mode is not None:
|
return self.renderer.get_renders()
|
||||||
return self.renderer.get_renders()
|
|
||||||
else:
|
|
||||||
return self._render(mode)
|
|
||||||
|
|
||||||
def _render(self, mode="human"):
|
def _render(self, mode="human"):
|
||||||
assert mode in self.metadata["render_modes"]
|
assert mode in self.metadata["render_modes"]
|
||||||
|
@@ -171,11 +171,8 @@ class PendulumEnv(gym.Env):
|
|||||||
theta, thetadot = self.state
|
theta, thetadot = self.state
|
||||||
return np.array([np.cos(theta), np.sin(theta), thetadot], dtype=np.float32)
|
return np.array([np.cos(theta), np.sin(theta), thetadot], dtype=np.float32)
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self):
|
||||||
if self.render_mode is not None:
|
return self.renderer.get_renders()
|
||||||
return self.renderer.get_renders()
|
|
||||||
else:
|
|
||||||
return self._render(mode)
|
|
||||||
|
|
||||||
def _render(self, mode="human"):
|
def _render(self, mode="human"):
|
||||||
assert mode in self.metadata["render_modes"]
|
assert mode in self.metadata["render_modes"]
|
||||||
|
@@ -176,32 +176,8 @@ class BaseMujocoEnv(gym.Env):
|
|||||||
raise ValueError("Action dimension mismatch")
|
raise ValueError("Action dimension mismatch")
|
||||||
self._step_mujoco_simulation(ctrl, n_frames)
|
self._step_mujoco_simulation(ctrl, n_frames)
|
||||||
|
|
||||||
def render(
|
def render(self):
|
||||||
self,
|
return self.renderer.get_renders()
|
||||||
mode: str = "human",
|
|
||||||
width: Optional[int] = None,
|
|
||||||
height: Optional[int] = None,
|
|
||||||
camera_id: Optional[int] = None,
|
|
||||||
camera_name: Optional[str] = None,
|
|
||||||
):
|
|
||||||
if self.render_mode is not None:
|
|
||||||
assert (
|
|
||||||
width is None
|
|
||||||
and height is None
|
|
||||||
and camera_id is None
|
|
||||||
and camera_name is None
|
|
||||||
), "Unexpected argument for render. Specify render arguments at environment initialization."
|
|
||||||
return self.renderer.get_renders()
|
|
||||||
else:
|
|
||||||
width = width if width is not None else DEFAULT_SIZE
|
|
||||||
height = height if height is not None else DEFAULT_SIZE
|
|
||||||
return self._render(
|
|
||||||
mode=mode,
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
camera_id=camera_id,
|
|
||||||
camera_name=camera_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self.viewer is not None:
|
if self.viewer is not None:
|
||||||
|
@@ -194,13 +194,10 @@ class BlackjackEnv(gym.Env):
|
|||||||
else:
|
else:
|
||||||
return self._get_obs(), {}
|
return self._get_obs(), {}
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self):
|
||||||
if self.render_mode is not None:
|
return self.renderer.get_renders()
|
||||||
return self.renderer.get_renders()
|
|
||||||
else:
|
|
||||||
return self._render(mode)
|
|
||||||
|
|
||||||
def _render(self, mode):
|
def _render(self, mode: str = "human"):
|
||||||
assert mode in self.metadata["render_modes"]
|
assert mode in self.metadata["render_modes"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@@ -166,11 +166,8 @@ class CliffWalkingEnv(Env):
|
|||||||
else:
|
else:
|
||||||
return int(self.s), {"prob": 1}
|
return int(self.s), {"prob": 1}
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self):
|
||||||
if self.render_mode is not None:
|
return self.renderer.get_renders()
|
||||||
return self.renderer.get_renders()
|
|
||||||
else:
|
|
||||||
return self._render(mode)
|
|
||||||
|
|
||||||
def _render(self, mode="human"):
|
def _render(self, mode="human"):
|
||||||
if mode == "ansi":
|
if mode == "ansi":
|
||||||
|
@@ -271,11 +271,8 @@ class FrozenLakeEnv(Env):
|
|||||||
else:
|
else:
|
||||||
return int(self.s), {"prob": 1}
|
return int(self.s), {"prob": 1}
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self):
|
||||||
if self.render_mode is not None:
|
return self.renderer.get_renders()
|
||||||
return self.renderer.get_renders()
|
|
||||||
else:
|
|
||||||
return self._render(mode)
|
|
||||||
|
|
||||||
def _render(self, mode="human"):
|
def _render(self, mode="human"):
|
||||||
assert mode in self.metadata["render_modes"]
|
assert mode in self.metadata["render_modes"]
|
||||||
|
@@ -280,11 +280,8 @@ class TaxiEnv(Env):
|
|||||||
else:
|
else:
|
||||||
return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)}
|
return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)}
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self):
|
||||||
if self.render_mode is not None:
|
return self.renderer.get_renders()
|
||||||
return self.renderer.get_renders()
|
|
||||||
else:
|
|
||||||
return self._render(mode)
|
|
||||||
|
|
||||||
def _render(self, mode):
|
def _render(self, mode):
|
||||||
assert mode in self.metadata["render_modes"]
|
assert mode in self.metadata["render_modes"]
|
||||||
|
@@ -1,7 +1,4 @@
|
|||||||
"""Utilities of visualising an environment."""
|
"""Utilities of visualising an environment."""
|
||||||
|
|
||||||
# TODO: Convert to new step API in 1.0
|
|
||||||
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -53,6 +50,12 @@ class PlayableGame:
|
|||||||
keys_to_action: The dictionary of keyboard tuples and action value
|
keys_to_action: The dictionary of keyboard tuples and action value
|
||||||
zoom: If to zoom in on the environment render
|
zoom: If to zoom in on the environment render
|
||||||
"""
|
"""
|
||||||
|
if env.render_mode not in {"rgb_array", "single_rgb_array"}:
|
||||||
|
logger.error(
|
||||||
|
"PlayableGame wrapper works only with rgb_array and single_rgb_array render modes, "
|
||||||
|
f"but your environment render_mode = {env.render_mode}."
|
||||||
|
)
|
||||||
|
|
||||||
self.env = env
|
self.env = env
|
||||||
self.relevant_keys = self._get_relevant_keys(keys_to_action)
|
self.relevant_keys = self._get_relevant_keys(keys_to_action)
|
||||||
self.video_size = self._get_video_size(zoom)
|
self.video_size = self._get_video_size(zoom)
|
||||||
@@ -78,8 +81,9 @@ class PlayableGame:
|
|||||||
return relevant_keys
|
return relevant_keys
|
||||||
|
|
||||||
def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]:
|
def _get_video_size(self, zoom: Optional[float] = None) -> Tuple[int, int]:
|
||||||
# TODO: this needs to be updated when the render API change goes through
|
rendered = self.env.render()
|
||||||
rendered = self.env.render(mode="rgb_array")
|
if isinstance(rendered, List):
|
||||||
|
rendered = rendered[-1]
|
||||||
assert rendered is not None and isinstance(rendered, np.ndarray)
|
assert rendered is not None and isinstance(rendered, np.ndarray)
|
||||||
video_size = [rendered.shape[1], rendered.shape[0]]
|
video_size = [rendered.shape[1], rendered.shape[0]]
|
||||||
|
|
||||||
@@ -146,7 +150,8 @@ def play(
|
|||||||
|
|
||||||
>>> import gym
|
>>> import gym
|
||||||
>>> from gym.utils.play import play
|
>>> from gym.utils.play import play
|
||||||
>>> play(gym.make("CarRacing-v1"), keys_to_action={"w": np.array([0, 0.7, 0]),
|
>>> play(gym.make("CarRacing-v1", render_mode="single_rgb_array"), keys_to_action={
|
||||||
|
... "w": np.array([0, 0.7, 0]),
|
||||||
... "a": np.array([-1, 0, 0]),
|
... "a": np.array([-1, 0, 0]),
|
||||||
... "s": np.array([0, 0, 1]),
|
... "s": np.array([0, 0, 1]),
|
||||||
... "d": np.array([1, 0, 0]),
|
... "d": np.array([1, 0, 0]),
|
||||||
@@ -214,6 +219,11 @@ def play(
|
|||||||
deprecation(
|
deprecation(
|
||||||
"`play.py` currently supports only the old step API which returns one boolean, however this will soon be updated to support only the new step api that returns two bools."
|
"`play.py` currently supports only the old step API which returns one boolean, however this will soon be updated to support only the new step api that returns two bools."
|
||||||
)
|
)
|
||||||
|
if env.render_mode not in {"rgb_array", "single_rgb_array"}:
|
||||||
|
logger.error(
|
||||||
|
"play method works only with rgb_array and single_rgb_array render modes, "
|
||||||
|
f"but your environment render_mode = {env.render_mode}."
|
||||||
|
)
|
||||||
|
|
||||||
env.reset(seed=seed)
|
env.reset(seed=seed)
|
||||||
|
|
||||||
@@ -255,8 +265,10 @@ def play(
|
|||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback(prev_obs, obs, action, rew, done, info)
|
callback(prev_obs, obs, action, rew, done, info)
|
||||||
if obs is not None:
|
if obs is not None:
|
||||||
# TODO: this needs to be updated when the render API change goes through
|
rendered = env.render()
|
||||||
rendered = env.render(mode="rgb_array")
|
if isinstance(rendered, List):
|
||||||
|
rendered = rendered[-1]
|
||||||
|
assert rendered is not None and isinstance(rendered, np.ndarray)
|
||||||
display_arr(
|
display_arr(
|
||||||
game.screen, rendered, transpose=transpose, video_size=game.video_size
|
game.screen, rendered, transpose=transpose, video_size=game.video_size
|
||||||
)
|
)
|
||||||
|
@@ -2,24 +2,20 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import os.path
|
import os.path
|
||||||
import pkgutil
|
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
import tempfile
|
import tempfile
|
||||||
from io import StringIO
|
from typing import List, Optional
|
||||||
from typing import List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from gym import error, logger
|
from gym import error, logger
|
||||||
|
|
||||||
|
try:
|
||||||
def touch(path: str):
|
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
|
||||||
"""Touch a filename at path."""
|
except ImportError:
|
||||||
open(path, "a").close()
|
raise error.DependencyNotInstalled(
|
||||||
|
"MoviePy is not installed, run `pip install moviepy`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class VideoRecorder: # TODO: remove with gym 1.0
|
class VideoRecorder:
|
||||||
"""VideoRecorder renders a nice movie of a rollout, frame by frame.
|
"""VideoRecorder renders a nice movie of a rollout, frame by frame.
|
||||||
|
|
||||||
It comes with an ``enabled`` option, so you can still use the same code on episodes where you don't want to record video.
|
It comes with an ``enabled`` option, so you can still use the same code on episodes where you don't want to record video.
|
||||||
@@ -35,7 +31,6 @@ class VideoRecorder: # TODO: remove with gym 1.0
|
|||||||
metadata: Optional[dict] = None,
|
metadata: Optional[dict] = None,
|
||||||
enabled: bool = True,
|
enabled: bool = True,
|
||||||
base_path: Optional[str] = None,
|
base_path: Optional[str] = None,
|
||||||
internal_backend_use: bool = False,
|
|
||||||
):
|
):
|
||||||
"""Video recorder renders a nice movie of a rollout, frame by frame.
|
"""Video recorder renders a nice movie of a rollout, frame by frame.
|
||||||
|
|
||||||
@@ -55,124 +50,55 @@ class VideoRecorder: # TODO: remove with gym 1.0
|
|||||||
self._closed = False
|
self._closed = False
|
||||||
|
|
||||||
self.render_history = []
|
self.render_history = []
|
||||||
self.last_frame = None
|
|
||||||
self.env = env
|
self.env = env
|
||||||
|
|
||||||
self.render_mode = env.render_mode
|
self.render_mode = env.render_mode
|
||||||
modes = env.metadata.get("render_modes", [])
|
|
||||||
|
|
||||||
# backward-compatibility mode:
|
|
||||||
backward_compatible_mode = env.metadata.get("render.modes", [])
|
|
||||||
if len(modes) == 0 and len(backward_compatible_mode) > 0:
|
|
||||||
logger.deprecation(
|
|
||||||
'`env.metadata["render.modes"] is marked as deprecated and will be replaced '
|
|
||||||
'with `env.metadata["render_modes"]` see https://github.com/openai/gym/pull/2654 for more details'
|
|
||||||
)
|
|
||||||
modes = backward_compatible_mode
|
|
||||||
|
|
||||||
self.ansi_mode = False
|
|
||||||
if "rgb_array" != self.render_mode and "single_rgb_array" != self.render_mode:
|
if "rgb_array" != self.render_mode and "single_rgb_array" != self.render_mode:
|
||||||
if self.render_mode is None and (
|
logger.warn(
|
||||||
"single_rgb_array" in modes or "rgb_array" in modes
|
f"Disabling video recorder because environment {env} was not initialized with any compatible video "
|
||||||
):
|
"mode between `single_rgb_array` and `rgb_array`"
|
||||||
logger.deprecation(
|
)
|
||||||
f"Recording ability for environment {env.spec.id} initialized with `render_mode=None` is marked "
|
# Disable since the environment has not been initialized with a compatible `render_mode`
|
||||||
"as deprecated and will be removed in the future."
|
self.enabled = False
|
||||||
)
|
|
||||||
elif "ansi" == env.render_mode:
|
|
||||||
self.ansi_mode = True
|
|
||||||
logger.deprecation(
|
|
||||||
f'Recording ability for environment {env} initialized with `render_mode="ansi"` is marked '
|
|
||||||
"as deprecated and will be removed in the future."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warn(
|
|
||||||
f"Disabling video recorder because environment {env} was not initialized with any compatible video "
|
|
||||||
"mode between `single_rgb_array` and `rgb_array`"
|
|
||||||
)
|
|
||||||
# Disable since the environment has not been initialized with a compatible `render_mode`
|
|
||||||
self.enabled = False
|
|
||||||
|
|
||||||
# Don't bother setting anything else if not enabled
|
# Don't bother setting anything else if not enabled
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not internal_backend_use:
|
|
||||||
logger.deprecation(
|
|
||||||
f"{self.__class__} is marked as deprecated and will be removed in the future."
|
|
||||||
)
|
|
||||||
|
|
||||||
if path is not None and base_path is not None:
|
if path is not None and base_path is not None:
|
||||||
raise error.Error("You can pass at most one of `path` or `base_path`.")
|
raise error.Error("You can pass at most one of `path` or `base_path`.")
|
||||||
|
|
||||||
required_ext = ".json" if self.ansi_mode else ".mp4"
|
required_ext = ".mp4"
|
||||||
if path is None:
|
if path is None:
|
||||||
if base_path is not None:
|
if base_path is not None:
|
||||||
# Base path given, append ext
|
# Base path given, append ext
|
||||||
path = base_path + required_ext
|
path = base_path + required_ext
|
||||||
else:
|
else:
|
||||||
# Otherwise, just generate a unique filename
|
# Otherwise, just generate a unique filename
|
||||||
with tempfile.NamedTemporaryFile(
|
with tempfile.NamedTemporaryFile(suffix=required_ext) as f:
|
||||||
suffix=required_ext, delete=False
|
|
||||||
) as f:
|
|
||||||
path = f.name
|
path = f.name
|
||||||
self.path = path
|
self.path = path
|
||||||
|
|
||||||
path_base, actual_ext = os.path.splitext(self.path)
|
path_base, actual_ext = os.path.splitext(self.path)
|
||||||
|
|
||||||
if actual_ext != required_ext:
|
if actual_ext != required_ext:
|
||||||
if self.ansi_mode:
|
|
||||||
hint = (
|
|
||||||
" HINT: The environment is text-only, "
|
|
||||||
"therefore we're recording its text output in a structured JSON format."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hint = ""
|
|
||||||
raise error.Error(
|
raise error.Error(
|
||||||
f"Invalid path given: {self.path} -- must have file extension {required_ext}.{hint}"
|
f"Invalid path given: {self.path} -- must have file extension {required_ext}."
|
||||||
)
|
)
|
||||||
# Touch the file in any case, so we know it's present. This corrects for platform platform differences.
|
|
||||||
# Using ffmpeg on OS X, the file is precreated, but not on Linux.
|
|
||||||
touch(path)
|
|
||||||
|
|
||||||
self.frames_per_sec = env.metadata.get("render_fps", 30)
|
self.frames_per_sec = env.metadata.get("render_fps", 30)
|
||||||
self.output_frames_per_sec = env.metadata.get("render_fps", self.frames_per_sec)
|
|
||||||
|
|
||||||
# backward-compatibility mode:
|
|
||||||
self.backward_compatible_frames_per_sec = env.metadata.get(
|
|
||||||
"video.frames_per_second", self.frames_per_sec
|
|
||||||
)
|
|
||||||
self.backward_compatible_output_frames_per_sec = env.metadata.get(
|
|
||||||
"video.output_frames_per_second", self.output_frames_per_sec
|
|
||||||
)
|
|
||||||
if self.frames_per_sec != self.backward_compatible_frames_per_sec:
|
|
||||||
logger.deprecation(
|
|
||||||
'`env.metadata["video.frames_per_second"] is marked as deprecated and will be replaced '
|
|
||||||
'with `env.metadata["render_fps"]` see https://github.com/openai/gym/pull/2654 for more details'
|
|
||||||
)
|
|
||||||
self.frames_per_sec = self.backward_compatible_frames_per_sec
|
|
||||||
if self.output_frames_per_sec != self.backward_compatible_output_frames_per_sec:
|
|
||||||
logger.deprecation(
|
|
||||||
'`env.metadata["video.output_frames_per_second"] is marked as deprecated and will be replaced '
|
|
||||||
'with `env.metadata["render_fps"]` see https://github.com/openai/gym/pull/2654 for more details'
|
|
||||||
)
|
|
||||||
self.output_frames_per_sec = self.backward_compatible_output_frames_per_sec
|
|
||||||
|
|
||||||
self.encoder: Optional[
|
|
||||||
Union[TextEncoder, ImageEncoder]
|
|
||||||
] = None # lazily start the process
|
|
||||||
self.broken = False
|
self.broken = False
|
||||||
|
|
||||||
# Dump metadata
|
# Dump metadata
|
||||||
self.metadata = metadata or {}
|
self.metadata = metadata or {}
|
||||||
self.metadata["content_type"] = (
|
self.metadata["content_type"] = "video/mp4"
|
||||||
"video/vnd.openai.ansivid" if self.ansi_mode else "video/mp4"
|
|
||||||
)
|
|
||||||
self.metadata_path = f"{path_base}.meta.json"
|
self.metadata_path = f"{path_base}.meta.json"
|
||||||
self.write_metadata()
|
self.write_metadata()
|
||||||
|
|
||||||
logger.info(f"Starting new video recorder writing to {self.path}")
|
logger.info(f"Starting new video recorder writing to {self.path}")
|
||||||
self.empty = True
|
self.recorded_frames = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def functional(self):
|
def functional(self):
|
||||||
@@ -181,14 +107,10 @@ class VideoRecorder: # TODO: remove with gym 1.0
|
|||||||
|
|
||||||
def capture_frame(self):
|
def capture_frame(self):
|
||||||
"""Render the given `env` and add the resulting frame to the video."""
|
"""Render the given `env` and add the resulting frame to the video."""
|
||||||
if self.render_mode is None:
|
frame = self.env.render()
|
||||||
frame = self.env.render(mode="rgb_array")
|
|
||||||
else:
|
|
||||||
frame = self.env.render()
|
|
||||||
if isinstance(frame, List):
|
if isinstance(frame, List):
|
||||||
self.render_history += frame
|
self.render_history += frame
|
||||||
frame = frame[-1]
|
frame = frame[-1]
|
||||||
self.last_frame = frame
|
|
||||||
|
|
||||||
if not self.functional:
|
if not self.functional:
|
||||||
return
|
return
|
||||||
@@ -211,10 +133,7 @@ class VideoRecorder: # TODO: remove with gym 1.0
|
|||||||
)
|
)
|
||||||
self.broken = True
|
self.broken = True
|
||||||
else:
|
else:
|
||||||
if self.ansi_mode:
|
self.recorded_frames.append(frame)
|
||||||
self._encode_ansi_frame(frame)
|
|
||||||
else:
|
|
||||||
self._encode_image_frame(frame)
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Flush all data to disk and close any open frame encoders."""
|
"""Flush all data to disk and close any open frame encoders."""
|
||||||
@@ -225,34 +144,16 @@ class VideoRecorder: # TODO: remove with gym 1.0
|
|||||||
self.env.close()
|
self.env.close()
|
||||||
|
|
||||||
# Close the encoder
|
# Close the encoder
|
||||||
if self.encoder:
|
if len(self.recorded_frames) > 0:
|
||||||
logger.debug("Closing video encoder: path=%s", self.path)
|
logger.debug("Closing video encoder: path=%s", self.path)
|
||||||
self.encoder.close()
|
clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec)
|
||||||
self.encoder = None
|
clip.write_videofile(self.path)
|
||||||
else:
|
else:
|
||||||
# No frames captured. Set metadata, and remove the empty output file.
|
# No frames captured. Set metadata.
|
||||||
os.remove(self.path)
|
|
||||||
|
|
||||||
if self.metadata is None:
|
if self.metadata is None:
|
||||||
self.metadata = {}
|
self.metadata = {}
|
||||||
self.metadata["empty"] = True
|
self.metadata["empty"] = True
|
||||||
|
|
||||||
# If broken, get rid of the output file, otherwise we'd leak it.
|
|
||||||
if self.broken:
|
|
||||||
logger.info(
|
|
||||||
"Cleaning up paths for broken video recorder: path=%s metadata_path=%s",
|
|
||||||
self.path,
|
|
||||||
self.metadata_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Might have crashed before even starting the output file, don't try to remove in that case.
|
|
||||||
if os.path.exists(self.path):
|
|
||||||
os.remove(self.path)
|
|
||||||
|
|
||||||
if self.metadata is None:
|
|
||||||
self.metadata = {}
|
|
||||||
self.metadata["broken"] = True
|
|
||||||
|
|
||||||
self.write_metadata()
|
self.write_metadata()
|
||||||
|
|
||||||
# Stop tracking this for autoclose
|
# Stop tracking this for autoclose
|
||||||
@@ -267,257 +168,3 @@ class VideoRecorder: # TODO: remove with gym 1.0
|
|||||||
"""Closes the environment correctly when the recorder is deleted."""
|
"""Closes the environment correctly when the recorder is deleted."""
|
||||||
# Make sure we've closed up shop when garbage collecting
|
# Make sure we've closed up shop when garbage collecting
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def _encode_ansi_frame(self, frame):
|
|
||||||
if not self.encoder:
|
|
||||||
self.encoder = TextEncoder(self.path, self.frames_per_sec)
|
|
||||||
self.metadata["encoder_version"] = self.encoder.version_info
|
|
||||||
self.encoder.capture_frame(frame)
|
|
||||||
self.empty = False
|
|
||||||
|
|
||||||
def _encode_image_frame(self, frame):
|
|
||||||
if not self.encoder:
|
|
||||||
self.encoder = ImageEncoder(
|
|
||||||
self.path, frame.shape, self.frames_per_sec, self.output_frames_per_sec
|
|
||||||
)
|
|
||||||
self.metadata["encoder_version"] = self.encoder.version_info
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.encoder.capture_frame(frame)
|
|
||||||
except error.InvalidFrame as e:
|
|
||||||
logger.warn("Tried to pass invalid video frame, marking as broken: %s", e)
|
|
||||||
self.broken = True
|
|
||||||
else:
|
|
||||||
self.empty = False
|
|
||||||
|
|
||||||
|
|
||||||
class TextEncoder:
|
|
||||||
"""Store a moving picture made out of ANSI frames.
|
|
||||||
|
|
||||||
Format adapted from https://github.com/asciinema/asciinema/blob/master/doc/asciicast-v1.md
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, output_path: str, frames_per_sec: int):
|
|
||||||
"""Stores a moving picture for an environment with ANSI frames.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
output_path: The output path of the frames
|
|
||||||
frames_per_sec: The number of frames per seconds for the output video
|
|
||||||
"""
|
|
||||||
self.output_path = output_path
|
|
||||||
self.frames_per_sec = frames_per_sec
|
|
||||||
self.frames = []
|
|
||||||
|
|
||||||
def capture_frame(self, frame: Union[str, StringIO]):
|
|
||||||
"""Captures an ANSI frame and adds it to the frames.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
frame: A string or StringIO frame
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
InvalidFrame: Wrong type for a frame, expects text frame to be a string or StringIO
|
|
||||||
"""
|
|
||||||
if isinstance(frame, str):
|
|
||||||
string = frame
|
|
||||||
elif isinstance(frame, StringIO):
|
|
||||||
string = frame.getvalue()
|
|
||||||
else:
|
|
||||||
raise error.InvalidFrame(
|
|
||||||
f"Wrong type {type(frame)} for {frame}: text frame must be a string or StringIO"
|
|
||||||
)
|
|
||||||
|
|
||||||
frame_bytes = string.encode("utf-8")
|
|
||||||
|
|
||||||
if frame_bytes[-1:] != b"\n":
|
|
||||||
raise error.InvalidFrame(f'Frame must end with a newline: """{string}"""')
|
|
||||||
|
|
||||||
if b"\r" in frame_bytes:
|
|
||||||
raise error.InvalidFrame(
|
|
||||||
f'Frame contains carriage returns (only newlines are allowed: """{string}"""'
|
|
||||||
)
|
|
||||||
|
|
||||||
self.frames.append(frame_bytes)
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""Closes the text encoder, dumping all data to output path."""
|
|
||||||
# frame_duration = float(1) / self.frames_per_sec
|
|
||||||
frame_duration = 0.5
|
|
||||||
|
|
||||||
# Turn frames into events: clear screen beforehand
|
|
||||||
# https://rosettacode.org/wiki/Terminal_control/Clear_the_screen#Python
|
|
||||||
# https://rosettacode.org/wiki/Terminal_control/Cursor_positioning#Python
|
|
||||||
clear_code = b"%c[2J\033[1;1H" % (27)
|
|
||||||
# Decode the bytes as UTF-8 since JSON may only contain UTF-8
|
|
||||||
events = [
|
|
||||||
(
|
|
||||||
frame_duration,
|
|
||||||
(clear_code + frame.replace(b"\n", b"\r\n")).decode("utf-8"),
|
|
||||||
)
|
|
||||||
for frame in self.frames
|
|
||||||
]
|
|
||||||
|
|
||||||
# Calculate frame size from the largest frames.
|
|
||||||
# Add some padding since we'll get cut off otherwise.
|
|
||||||
height = max(frame.count(b"\n") for frame in self.frames) + 1
|
|
||||||
width = (
|
|
||||||
max(max(len(line) for line in frame.split(b"\n")) for frame in self.frames)
|
|
||||||
+ 2
|
|
||||||
)
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"version": 1,
|
|
||||||
"width": width,
|
|
||||||
"height": height,
|
|
||||||
"duration": len(self.frames) * frame_duration,
|
|
||||||
"command": "-",
|
|
||||||
"title": "gym VideoRecorder episode",
|
|
||||||
"env": {}, # could add some env metadata here
|
|
||||||
"stdout": events,
|
|
||||||
}
|
|
||||||
|
|
||||||
with open(self.output_path, "w") as f:
|
|
||||||
json.dump(data, f)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def version_info(self):
|
|
||||||
"""Returns the version info, backend=TextEncoder and Version number=1."""
|
|
||||||
return {"backend": "TextEncoder", "version": 1}
|
|
||||||
|
|
||||||
|
|
||||||
class ImageEncoder:
|
|
||||||
"""Captures image based frames of environments for Video Recorder."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
output_path: str,
|
|
||||||
frame_shape: Tuple[int, int, int],
|
|
||||||
frames_per_sec: int,
|
|
||||||
output_frames_per_sec: int,
|
|
||||||
):
|
|
||||||
"""Encoder for capturing image based frames of environment for Video Recorder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
output_path: The output data path
|
|
||||||
frame_shape: The expected frame shape, a tuple of height, weight and channels (3 or 4)
|
|
||||||
frames_per_sec: The number of frames per second the environment runs at
|
|
||||||
output_frames_per_sec: The output number of frames per second for the video
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
InvalidFrame: Expects frame to have shape (w,h,3) or (w,h,4)
|
|
||||||
DependencyNotInstalled: Found neither the ffmpeg nor avconv executables.
|
|
||||||
"""
|
|
||||||
self.proc: Optional[subprocess.Popen] = None
|
|
||||||
self.output_path = output_path
|
|
||||||
# Frame shape should be lines-first, so w and h are swapped
|
|
||||||
h, w, pixfmt = frame_shape
|
|
||||||
if pixfmt != 3 and pixfmt != 4:
|
|
||||||
raise error.InvalidFrame(
|
|
||||||
f"Your frame has shape {frame_shape}, but we require (w,h,3) or (w,h,4), "
|
|
||||||
"i.e., RGB values for a w-by-h image, with an optional alpha channel."
|
|
||||||
)
|
|
||||||
self.wh = (w, h)
|
|
||||||
self.includes_alpha = pixfmt == 4
|
|
||||||
self.frame_shape = frame_shape
|
|
||||||
self.frames_per_sec = frames_per_sec
|
|
||||||
self.output_frames_per_sec = output_frames_per_sec
|
|
||||||
|
|
||||||
if shutil.which("avconv") is not None:
|
|
||||||
self.backend = "avconv"
|
|
||||||
elif shutil.which("ffmpeg") is not None:
|
|
||||||
self.backend = "ffmpeg"
|
|
||||||
elif pkgutil.find_loader("imageio_ffmpeg"):
|
|
||||||
import imageio_ffmpeg
|
|
||||||
|
|
||||||
self.backend = imageio_ffmpeg.get_ffmpeg_exe()
|
|
||||||
else:
|
|
||||||
raise error.DependencyNotInstalled(
|
|
||||||
"Found neither the ffmpeg nor avconv executables. "
|
|
||||||
"On OS X, you can install ffmpeg via `brew install ffmpeg`. "
|
|
||||||
"On most Ubuntu variants, `sudo apt-get install ffmpeg` should do it. "
|
|
||||||
"On Ubuntu 14.04, however, you'll need to install avconv with `sudo apt-get install libav-tools`. "
|
|
||||||
"Alternatively, please install imageio-ffmpeg with `pip install imageio-ffmpeg`"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.start()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def version_info(self):
|
|
||||||
"""Returns the version info: backend, version and cmdline."""
|
|
||||||
return {
|
|
||||||
"backend": self.backend,
|
|
||||||
"version": str(
|
|
||||||
subprocess.check_output(
|
|
||||||
[self.backend, "-version"], stderr=subprocess.STDOUT
|
|
||||||
)
|
|
||||||
),
|
|
||||||
"cmdline": self.cmdline,
|
|
||||||
}
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
"""Starts a subprocess using the backend and cmdline."""
|
|
||||||
self.cmdline = (
|
|
||||||
self.backend,
|
|
||||||
"-nostats",
|
|
||||||
"-loglevel",
|
|
||||||
"error", # suppress warnings
|
|
||||||
"-y",
|
|
||||||
# input
|
|
||||||
"-f",
|
|
||||||
"rawvideo",
|
|
||||||
"-s:v",
|
|
||||||
"{}x{}".format(*self.wh),
|
|
||||||
"-pix_fmt",
|
|
||||||
("rgb32" if self.includes_alpha else "rgb24"),
|
|
||||||
"-framerate",
|
|
||||||
"%d" % self.frames_per_sec,
|
|
||||||
"-i",
|
|
||||||
"-", # this used to be /dev/stdin, which is not Windows-friendly
|
|
||||||
# output
|
|
||||||
"-vf",
|
|
||||||
"scale=trunc(iw/2)*2:trunc(ih/2)*2",
|
|
||||||
"-vcodec",
|
|
||||||
"libx264",
|
|
||||||
"-pix_fmt",
|
|
||||||
"yuv420p",
|
|
||||||
"-r",
|
|
||||||
"%d" % self.output_frames_per_sec,
|
|
||||||
self.output_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug('Starting %s with "%s"', self.backend, " ".join(self.cmdline))
|
|
||||||
if hasattr(os, "setsid"): # setsid not present on Windows
|
|
||||||
self.proc = subprocess.Popen(
|
|
||||||
self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE)
|
|
||||||
|
|
||||||
def capture_frame(self, frame: Union[np.ndarray, np.generic]):
|
|
||||||
"""Captures a frame writing it to the backend subprocess."""
|
|
||||||
if not isinstance(frame, (np.ndarray, np.generic)):
|
|
||||||
raise error.InvalidFrame(
|
|
||||||
f"Wrong type {type(frame)} for {frame} (must be np.ndarray or np.generic)"
|
|
||||||
)
|
|
||||||
if frame.shape != self.frame_shape:
|
|
||||||
raise error.InvalidFrame(
|
|
||||||
f"Your frame has shape {frame.shape}, but the VideoRecorder is configured for shape {self.frame_shape}."
|
|
||||||
)
|
|
||||||
if frame.dtype != np.uint8:
|
|
||||||
raise error.InvalidFrame(
|
|
||||||
f"Your frame has data type {frame.dtype}, but we require uint8 (i.e. RGB values from 0-255)."
|
|
||||||
)
|
|
||||||
|
|
||||||
assert self.proc is not None and self.proc.stdin is not None
|
|
||||||
try:
|
|
||||||
self.proc.stdin.write(frame.tobytes())
|
|
||||||
except Exception:
|
|
||||||
stdout, stderr = self.proc.communicate()
|
|
||||||
logger.error("VideoRecorder encoder failed: %s", stderr)
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""Closes the Image encoder."""
|
|
||||||
assert self.proc is not None and self.proc.stdin is not None
|
|
||||||
self.proc.stdin.close()
|
|
||||||
ret = self.proc.wait()
|
|
||||||
if ret != 0:
|
|
||||||
logger.error(f"VideoRecorder encoder exited with status {ret}")
|
|
||||||
|
@@ -122,7 +122,6 @@ class RecordVideo(gym.Wrapper):
|
|||||||
env=self.env,
|
env=self.env,
|
||||||
base_path=base_path,
|
base_path=base_path,
|
||||||
metadata={"step_id": self.step_id, "episode_id": self.episode_id},
|
metadata={"step_id": self.step_id, "episode_id": self.episode_id},
|
||||||
internal_backend_use=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.video_recorder.capture_frame()
|
self.video_recorder.capture_frame()
|
||||||
|
2
setup.py
2
setup.py
@@ -20,7 +20,7 @@ extras = {
|
|||||||
"mujoco_py": ["mujoco_py<2.2,>=2.1"],
|
"mujoco_py": ["mujoco_py<2.2,>=2.1"],
|
||||||
"mujoco": ["mujoco==2.2.0", "imageio>=2.14.1"],
|
"mujoco": ["mujoco==2.2.0", "imageio>=2.14.1"],
|
||||||
"toy_text": ["pygame==2.1.0"],
|
"toy_text": ["pygame==2.1.0"],
|
||||||
"other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0"],
|
"other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0", "moviepy>=1.0.0"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Testing dependency groups.
|
# Testing dependency groups.
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from typing import Callable
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pygame
|
import pygame
|
||||||
@@ -22,6 +22,9 @@ class DummyEnvSpec:
|
|||||||
|
|
||||||
|
|
||||||
class DummyPlayEnv(gym.Env):
|
class DummyPlayEnv(gym.Env):
|
||||||
|
def __init__(self, render_mode: Optional[str] = None):
|
||||||
|
self.render_mode = render_mode
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
obs = np.zeros((1, 1))
|
obs = np.zeros((1, 1))
|
||||||
rew, done, info = 1, False, {}
|
rew, done, info = 1, False, {}
|
||||||
@@ -30,7 +33,7 @@ class DummyPlayEnv(gym.Env):
|
|||||||
def reset(self, seed=None):
|
def reset(self, seed=None):
|
||||||
...
|
...
|
||||||
|
|
||||||
def render(self, mode="rgb_array"):
|
def render(self):
|
||||||
return np.zeros((1, 1))
|
return np.zeros((1, 1))
|
||||||
|
|
||||||
|
|
||||||
@@ -73,13 +76,13 @@ def close_pygame():
|
|||||||
|
|
||||||
|
|
||||||
def test_play_relevant_keys():
|
def test_play_relevant_keys():
|
||||||
env = DummyPlayEnv()
|
env = DummyPlayEnv(render_mode="single_rgb_array")
|
||||||
game = PlayableGame(env, dummy_keys_to_action())
|
game = PlayableGame(env, dummy_keys_to_action())
|
||||||
assert game.relevant_keys == {RELEVANT_KEY_1, RELEVANT_KEY_2}
|
assert game.relevant_keys == {RELEVANT_KEY_1, RELEVANT_KEY_2}
|
||||||
|
|
||||||
|
|
||||||
def test_play_relevant_keys_no_mapping():
|
def test_play_relevant_keys_no_mapping():
|
||||||
env = DummyPlayEnv()
|
env = DummyPlayEnv(render_mode="single_rgb_array")
|
||||||
env.spec = DummyEnvSpec("DummyPlayEnv")
|
env.spec = DummyEnvSpec("DummyPlayEnv")
|
||||||
|
|
||||||
with pytest.raises(MissingKeysToAction):
|
with pytest.raises(MissingKeysToAction):
|
||||||
@@ -88,27 +91,27 @@ def test_play_relevant_keys_no_mapping():
|
|||||||
|
|
||||||
def test_play_relevant_keys_with_env_attribute():
|
def test_play_relevant_keys_with_env_attribute():
|
||||||
"""Env has a keys_to_action attribute"""
|
"""Env has a keys_to_action attribute"""
|
||||||
env = DummyPlayEnv()
|
env = DummyPlayEnv(render_mode="single_rgb_array")
|
||||||
env.get_keys_to_action = dummy_keys_to_action
|
env.get_keys_to_action = dummy_keys_to_action
|
||||||
game = PlayableGame(env)
|
game = PlayableGame(env)
|
||||||
assert game.relevant_keys == {RELEVANT_KEY_1, RELEVANT_KEY_2}
|
assert game.relevant_keys == {RELEVANT_KEY_1, RELEVANT_KEY_2}
|
||||||
|
|
||||||
|
|
||||||
def test_video_size_no_zoom():
|
def test_video_size_no_zoom():
|
||||||
env = DummyPlayEnv()
|
env = DummyPlayEnv(render_mode="single_rgb_array")
|
||||||
game = PlayableGame(env, dummy_keys_to_action())
|
game = PlayableGame(env, dummy_keys_to_action())
|
||||||
assert game.video_size == list(env.render().shape)
|
assert game.video_size == list(env.render().shape)
|
||||||
|
|
||||||
|
|
||||||
def test_video_size_zoom():
|
def test_video_size_zoom():
|
||||||
env = DummyPlayEnv()
|
env = DummyPlayEnv(render_mode="single_rgb_array")
|
||||||
zoom = 2.2
|
zoom = 2.2
|
||||||
game = PlayableGame(env, dummy_keys_to_action(), zoom)
|
game = PlayableGame(env, dummy_keys_to_action(), zoom)
|
||||||
assert game.video_size == tuple(int(shape * zoom) for shape in env.render().shape)
|
assert game.video_size == tuple(int(shape * zoom) for shape in env.render().shape)
|
||||||
|
|
||||||
|
|
||||||
def test_keyboard_quit_event():
|
def test_keyboard_quit_event():
|
||||||
env = DummyPlayEnv()
|
env = DummyPlayEnv(render_mode="single_rgb_array")
|
||||||
game = PlayableGame(env, dummy_keys_to_action())
|
game = PlayableGame(env, dummy_keys_to_action())
|
||||||
event = Event(pygame.KEYDOWN, {"key": pygame.K_ESCAPE})
|
event = Event(pygame.KEYDOWN, {"key": pygame.K_ESCAPE})
|
||||||
assert game.running is True
|
assert game.running is True
|
||||||
@@ -117,7 +120,7 @@ def test_keyboard_quit_event():
|
|||||||
|
|
||||||
|
|
||||||
def test_pygame_quit_event():
|
def test_pygame_quit_event():
|
||||||
env = DummyPlayEnv()
|
env = DummyPlayEnv(render_mode="single_rgb_array")
|
||||||
game = PlayableGame(env, dummy_keys_to_action())
|
game = PlayableGame(env, dummy_keys_to_action())
|
||||||
event = Event(pygame.QUIT)
|
event = Event(pygame.QUIT)
|
||||||
assert game.running is True
|
assert game.running is True
|
||||||
@@ -126,7 +129,7 @@ def test_pygame_quit_event():
|
|||||||
|
|
||||||
|
|
||||||
def test_keyboard_relevant_keydown_event():
|
def test_keyboard_relevant_keydown_event():
|
||||||
env = DummyPlayEnv()
|
env = DummyPlayEnv(render_mode="single_rgb_array")
|
||||||
game = PlayableGame(env, dummy_keys_to_action())
|
game = PlayableGame(env, dummy_keys_to_action())
|
||||||
event = Event(pygame.KEYDOWN, {"key": RELEVANT_KEY_1})
|
event = Event(pygame.KEYDOWN, {"key": RELEVANT_KEY_1})
|
||||||
game.process_event(event)
|
game.process_event(event)
|
||||||
@@ -134,7 +137,7 @@ def test_keyboard_relevant_keydown_event():
|
|||||||
|
|
||||||
|
|
||||||
def test_keyboard_irrelevant_keydown_event():
|
def test_keyboard_irrelevant_keydown_event():
|
||||||
env = DummyPlayEnv()
|
env = DummyPlayEnv(render_mode="single_rgb_array")
|
||||||
game = PlayableGame(env, dummy_keys_to_action())
|
game = PlayableGame(env, dummy_keys_to_action())
|
||||||
event = Event(pygame.KEYDOWN, {"key": IRRELEVANT_KEY})
|
event = Event(pygame.KEYDOWN, {"key": IRRELEVANT_KEY})
|
||||||
game.process_event(event)
|
game.process_event(event)
|
||||||
@@ -142,7 +145,7 @@ def test_keyboard_irrelevant_keydown_event():
|
|||||||
|
|
||||||
|
|
||||||
def test_keyboard_keyup_event():
|
def test_keyboard_keyup_event():
|
||||||
env = DummyPlayEnv()
|
env = DummyPlayEnv(render_mode="single_rgb_array")
|
||||||
game = PlayableGame(env, dummy_keys_to_action())
|
game = PlayableGame(env, dummy_keys_to_action())
|
||||||
event = Event(pygame.KEYDOWN, {"key": RELEVANT_KEY_1})
|
event = Event(pygame.KEYDOWN, {"key": RELEVANT_KEY_1})
|
||||||
game.process_event(event)
|
game.process_event(event)
|
||||||
@@ -186,7 +189,7 @@ def test_play_loop_real_env():
|
|||||||
|
|
||||||
return obs_t, obs_tp1, action, rew, done, info
|
return obs_t, obs_tp1, action, rew, done, info
|
||||||
|
|
||||||
env = gym.make(ENV, disable_env_checker=True)
|
env = gym.make(ENV, render_mode="single_rgb_array", disable_env_checker=True)
|
||||||
env.reset(seed=SEED)
|
env.reset(seed=SEED)
|
||||||
keys_to_action = (
|
keys_to_action = (
|
||||||
dummy_keys_to_action_str() if str_keys else dummy_keys_to_action()
|
dummy_keys_to_action_str() if str_keys else dummy_keys_to_action()
|
||||||
@@ -199,7 +202,9 @@ def test_play_loop_real_env():
|
|||||||
action = keys_to_action[chr(e.key) if str_keys else (e.key,)]
|
action = keys_to_action[chr(e.key) if str_keys else (e.key,)]
|
||||||
obs, _, _, _ = env.step(action)
|
obs, _, _, _ = env.step(action)
|
||||||
|
|
||||||
env_play = gym.make(ENV, disable_env_checker=True)
|
env_play = gym.make(
|
||||||
|
ENV, render_mode="single_rgb_array", disable_env_checker=True
|
||||||
|
)
|
||||||
if apply_wrapper:
|
if apply_wrapper:
|
||||||
env_play = KeysToActionWrapper(env, keys_to_action=keys_to_action)
|
env_play = KeysToActionWrapper(env, keys_to_action=keys_to_action)
|
||||||
assert hasattr(env_play, "get_keys_to_action")
|
assert hasattr(env_play, "get_keys_to_action")
|
||||||
|
@@ -21,7 +21,7 @@ def test_gym_make_order_enforcing(spec):
|
|||||||
def test_order_enforcing():
|
def test_order_enforcing():
|
||||||
"""Checks that the order enforcing works as expected, raising an error before reset is called and not after."""
|
"""Checks that the order enforcing works as expected, raising an error before reset is called and not after."""
|
||||||
# The reason for not using gym.make is that all environments are by default wrapped in the order enforcing wrapper
|
# The reason for not using gym.make is that all environments are by default wrapped in the order enforcing wrapper
|
||||||
env = CartPoleEnv()
|
env = CartPoleEnv(render_mode="rgb_array")
|
||||||
assert not has_wrapper(env, OrderEnforcing)
|
assert not has_wrapper(env, OrderEnforcing)
|
||||||
|
|
||||||
# Assert that the order enforcing works for step and render before reset
|
# Assert that the order enforcing works for step and render before reset
|
||||||
@@ -30,16 +30,16 @@ def test_order_enforcing():
|
|||||||
with pytest.raises(ResetNeeded):
|
with pytest.raises(ResetNeeded):
|
||||||
order_enforced_env.step(0)
|
order_enforced_env.step(0)
|
||||||
with pytest.raises(ResetNeeded):
|
with pytest.raises(ResetNeeded):
|
||||||
order_enforced_env.render(mode="rgb_array")
|
order_enforced_env.render()
|
||||||
assert order_enforced_env.has_reset is False
|
assert order_enforced_env.has_reset is False
|
||||||
|
|
||||||
# Assert that the Assertion errors are not raised after reset
|
# Assert that the Assertion errors are not raised after reset
|
||||||
order_enforced_env.reset()
|
order_enforced_env.reset()
|
||||||
assert order_enforced_env.has_reset is True
|
assert order_enforced_env.has_reset is True
|
||||||
order_enforced_env.step(0)
|
order_enforced_env.step(0)
|
||||||
order_enforced_env.render(mode="rgb_array")
|
order_enforced_env.render()
|
||||||
|
|
||||||
# Assert that with disable_render_order_enforcing works, the environment has already been reset
|
# Assert that with disable_render_order_enforcing works, the environment has already been reset
|
||||||
env = CartPoleEnv()
|
env = CartPoleEnv(render_mode="rgb_array")
|
||||||
env = OrderEnforcing(env, disable_render_order_enforcing=True)
|
env = OrderEnforcing(env, disable_render_order_enforcing=True)
|
||||||
env.render(mode="rgb_array") # no assertion error
|
env.render() # no assertion error
|
||||||
|
@@ -14,7 +14,7 @@ class BrokenRecordableEnv(gym.Env):
|
|||||||
def __init__(self, render_mode="rgb_array"):
|
def __init__(self, render_mode="rgb_array"):
|
||||||
self.render_mode = render_mode
|
self.render_mode = render_mode
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -24,7 +24,7 @@ class UnrecordableEnv(gym.Env):
|
|||||||
def __init__(self, render_mode=None):
|
def __init__(self, render_mode=None):
|
||||||
self.render_mode = render_mode
|
self.render_mode = render_mode
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -33,15 +33,9 @@ def test_record_simple():
|
|||||||
rec = VideoRecorder(env)
|
rec = VideoRecorder(env)
|
||||||
env.reset()
|
env.reset()
|
||||||
rec.capture_frame()
|
rec.capture_frame()
|
||||||
assert rec.encoder is not None
|
|
||||||
proc = rec.encoder.proc
|
|
||||||
|
|
||||||
assert proc is not None and proc.poll() is None # subprocess is running
|
|
||||||
|
|
||||||
rec.close()
|
rec.close()
|
||||||
|
|
||||||
assert proc.poll() is not None # subprocess is terminated
|
|
||||||
assert not rec.empty
|
|
||||||
assert not rec.broken
|
assert not rec.broken
|
||||||
assert os.path.exists(rec.path)
|
assert os.path.exists(rec.path)
|
||||||
f = open(rec.path)
|
f = open(rec.path)
|
||||||
@@ -56,21 +50,16 @@ def test_autoclose():
|
|||||||
rec.capture_frame()
|
rec.capture_frame()
|
||||||
|
|
||||||
rec_path = rec.path
|
rec_path = rec.path
|
||||||
assert rec.encoder is not None
|
|
||||||
proc = rec.encoder.proc
|
|
||||||
|
|
||||||
assert proc is not None and proc.poll() is None # subprocess is running
|
|
||||||
|
|
||||||
# The function ends without an explicit `rec.close()` call
|
# The function ends without an explicit `rec.close()` call
|
||||||
# The Python interpreter will implicitly do `del rec` on garbage cleaning
|
# The Python interpreter will implicitly do `del rec` on garbage cleaning
|
||||||
return rec_path, proc
|
return rec_path
|
||||||
|
|
||||||
rec_path, proc = record()
|
rec_path = record()
|
||||||
|
|
||||||
gc.collect() # do explicit garbage collection for test
|
gc.collect() # do explicit garbage collection for test
|
||||||
time.sleep(5) # wait for subprocess exiting
|
time.sleep(5) # wait for subprocess exiting
|
||||||
|
|
||||||
assert proc is not None and proc.poll() is not None # subprocess is terminated
|
|
||||||
assert os.path.exists(rec_path)
|
assert os.path.exists(rec_path)
|
||||||
f = open(rec_path)
|
f = open(rec_path)
|
||||||
assert os.fstat(f.fileno()).st_size > 100
|
assert os.fstat(f.fileno()).st_size > 100
|
||||||
@@ -80,7 +69,6 @@ def test_no_frames():
|
|||||||
env = BrokenRecordableEnv()
|
env = BrokenRecordableEnv()
|
||||||
rec = VideoRecorder(env)
|
rec = VideoRecorder(env)
|
||||||
rec.close()
|
rec.close()
|
||||||
assert rec.empty
|
|
||||||
assert rec.functional
|
assert rec.functional
|
||||||
assert not os.path.exists(rec.path)
|
assert not os.path.exists(rec.path)
|
||||||
|
|
||||||
@@ -98,7 +86,6 @@ def test_record_breaking_render_method():
|
|||||||
rec = VideoRecorder(env)
|
rec = VideoRecorder(env)
|
||||||
rec.capture_frame()
|
rec.capture_frame()
|
||||||
rec.close()
|
rec.close()
|
||||||
assert rec.empty
|
|
||||||
assert rec.broken
|
assert rec.broken
|
||||||
assert not os.path.exists(rec.path)
|
assert not os.path.exists(rec.path)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user