Render API (#2671)

* add pygame GUI for frozen_lake.py env

* add new line at EOF

* pre-commit reformat

* improve graphics

* new images and dynamic window size

* darker tile borders and fix ICC profile

* pre-commit hook

* adjust elf and stool size

* Update frozen_lake.py

* reformat

* fix #2600

* #2600

* add rgb_array support

* reformat

* test render api change on FrozenLake

* add render support for reset on frozenlake

* add clock on pygame render

* new render api for blackjack

* new render api for cliffwalking

* new render api for Env class

* update reset method, lunar and Env

* fix wrapper

* fix reset lunar

* new render api for box2d envs

* new render api for mujoco envs

* fix bug

* new render api for classic control envs

* fix tests

* add render_mode None for CartPole

* new render api for test fake envs

* pre-commit hook

* fix FrozenLake

* fix FrozenLake

* more render_mode to super - frozenlake

* remove kwargs from frozen_lake new

* pre-commit hook

* add deprecated render method

* add backwards compatibility

* fix test

* add _render

* move pygame.init() (avoid pygame dependency on init)

* fix pygame dependencies

* remove collect_render() maintain multi-behaviours .render()

* add type hints

* fix renderer

* don't call .render() with None

* improve docstring

* add single_rgb_array to all envs

* remove None from metadata["render_modes"]

* add type hints to test_env_checkers

* fix lint

* add comments to renderer

* add comments to single_depth_array and single_state_pixels

* reformat

* add deprecation warnings and env.render_mode declaration

* fix lint

* reformat

* fix tests

* add docs

* fix car racing determinism

* remove warning test envs, customizable modes on renderer

* remove comments and add todo for env_checker

* fix car racing

* replace render mode check with assert

* update new mujoco

* reformat

* reformat

* change metaclass definition

* fix tests

* implement mark suggestions (test, docs, sets)

* check_render

Co-authored-by: J K Terry <jkterry0@gmail.com>
This commit is contained in:
Omar Younis
2022-06-08 00:20:56 +02:00
committed by GitHub
parent 66c431d4b3
commit 9acf9cd367
58 changed files with 950 additions and 304 deletions

View File

@@ -1,6 +1,16 @@
"""Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper.""" """Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper."""
import sys import sys
from typing import Generic, Optional, SupportsFloat, Tuple, TypeVar, Union from typing import (
Any,
Dict,
Generic,
List,
Optional,
SupportsFloat,
Tuple,
TypeVar,
Union,
)
from gym import spaces from gym import spaces
from gym.logger import deprecation, warn from gym.logger import deprecation, warn
@@ -14,6 +24,44 @@ if sys.version_info == (3, 6):
ObsType = TypeVar("ObsType") ObsType = TypeVar("ObsType")
ActType = TypeVar("ActType") ActType = TypeVar("ActType")
RenderFrame = TypeVar("RenderFrame")
class _EnvDecorator(type): # TODO: remove with gym 1.0
"""Metaclass used for adding deprecation warning to the mode kwarg in the render method."""
def __new__(cls, name, bases, attr):
if "render" in attr.keys():
attr["render"] = _EnvDecorator._deprecate_mode(attr["render"])
return super().__new__(cls, name, bases, attr)
@staticmethod
def _deprecate_mode(render_func): # type: ignore
render_return = Optional[Union[RenderFrame, List[RenderFrame]]]
def render(
self: object, *args: Tuple[Any], **kwargs: Dict[str, Any]
) -> render_return:
if "mode" in kwargs.keys():
deprecation(
"The argument mode in render method is deprecated; "
"use render_mode during environment initialization instead.\n"
"See here for more information: https://www.gymlibrary.ml/content/api/"
)
elif self.spec is not None and "render_mode" not in self.spec.kwargs.keys(): # type: ignore
deprecation(
"You are calling render method, "
"but you didn't specified the argument render_mode at environment initialization. "
"To maintain backward compatibility, the environment will render in human mode.\n"
"If you want to render in human mode, initialize the environment in this way: "
"gym.make('EnvName', render_mode='human') and don't call the render method.\n"
"See here for more information: https://www.gymlibrary.ml/content/api/"
)
return render_func(self, *args, **kwargs)
return render
class Env(Generic[ObsType, ActType]): class Env(Generic[ObsType, ActType]):
@@ -43,8 +91,11 @@ class Env(Generic[ObsType, ActType]):
Note: a default reward range set to :math:`(-\infty,+\infty)` already exists. Set it if you want a narrower range. Note: a default reward range set to :math:`(-\infty,+\infty)` already exists. Set it if you want a narrower range.
""" """
__metaclass__ = _EnvDecorator
# Set this in SOME subclasses # Set this in SOME subclasses
metadata = {"render_modes": []} metadata = {"render_modes": []}
render_mode = None # define render_mode if your environment supports rendering
reward_range = (-float("inf"), float("inf")) reward_range = (-float("inf"), float("inf"))
spec = None spec = None
@@ -130,42 +181,34 @@ class Env(Generic[ObsType, ActType]):
if seed is not None: if seed is not None:
self._np_random, seed = seeding.np_random(seed) self._np_random, seed = seeding.np_random(seed)
def render(self, mode="human"): # TODO: remove kwarg mode with gym 1.0
"""Renders the environment. def render(self, mode="human") -> Optional[Union[RenderFrame, List[RenderFrame]]]:
"""Compute the render frames as specified by render_mode attribute during initialization of the environment.
A set of supported modes varies per environment. (And some The set of supported modes varies per environment. (And some
third-party environments may not support rendering at all.) third-party environments may not support rendering at all.)
By convention, if mode is: By convention, if render_mode is:
- human: render to the current display or terminal and - None (default): no render is computed.
return nothing. Usually for human consumption. - human: render return None.
- rgb_array: Return a numpy.ndarray with shape (x, y, 3), The environment is continuously rendered in the current display or terminal. Usually for human consumption.
representing RGB values for an x-by-y pixel image, suitable - single_rgb_array: return a single frame representing the current state of the environment.
for turning into a video. A frame is a numpy.ndarray with shape (x, y, 3) representing RGB values for an x-by-y pixel image.
- ansi: Return a string (str) or StringIO.StringIO containing a - rgb_array: return a list of frames representing the states of the environment since the last reset.
terminal-style text representation. The text can include newlines Each frame is a numpy.ndarray with shape (x, y, 3), as with single_rgb_array.
and ANSI escape sequences (e.g. for colors). - ansi: Return a list of strings (str) or StringIO.StringIO containing a
terminal-style text representation for each time step.
The text can include newlines and ANSI escape sequences (e.g. for colors).
Note:
Rendering computations is performed internally even if you don't call render().
To avoid this, you can set render_mode = None and, if the environment supports it,
call render() specifying the argument 'mode'.
Note: Note:
Make sure that your class's metadata 'render_modes' key includes Make sure that your class's metadata 'render_modes' key includes
the list of supported modes. It's recommended to call super() the list of supported modes. It's recommended to call super()
in implementations to use the functionality of this method. in implementations to use the functionality of this method.
Example:
>>> import numpy as np
>>> class MyEnv(Env):
... metadata = {'render_modes': ['human', 'rgb_array']}
...
... def render(self, mode='human'):
... if mode == 'rgb_array':
... return np.array(...) # return RGB frame suitable for video
... elif mode == 'human':
... ... # pop up a window and render
... else:
... super().render(mode=mode) # just raise an exception
Args:
mode: the mode to render with, valid modes are `env.metadata["render_modes"]`
""" """
raise NotImplementedError raise NotImplementedError

View File

@@ -9,6 +9,7 @@ import gym
from gym import error, spaces from gym import error, spaces
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils import EzPickle from gym.utils import EzPickle
from gym.utils.renderer import Renderer
try: try:
import Box2D import Box2D
@@ -159,12 +160,13 @@ class BipedalWalker(gym.Env, EzPickle):
""" """
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": FPS} metadata = {
"render_modes": ["human", "rgb_array", "single_rgb_array"],
"render_fps": FPS,
}
def __init__(self, hardcore: bool = False): def __init__(self, render_mode: Optional[str] = None, hardcore: bool = False):
EzPickle.__init__(self) EzPickle.__init__(self)
self.screen = None
self.clock = None
self.isopen = True self.isopen = True
self.world = Box2D.b2World() self.world = Box2D.b2World()
@@ -252,6 +254,12 @@ class BipedalWalker(gym.Env, EzPickle):
# ] # ]
# state += [l.fraction for l in self.lidar] # state += [l.fraction for l in self.lidar]
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
self.screen = None
self.clock = None
def _destroy(self): def _destroy(self):
if not self.terrain: if not self.terrain:
return return
@@ -500,6 +508,7 @@ class BipedalWalker(gym.Env, EzPickle):
return fraction return fraction
self.lidar = [LidarCallback() for _ in range(10)] self.lidar = [LidarCallback() for _ in range(10)]
self.renderer.reset()
if not return_info: if not return_info:
return self.step(np.array([0, 0, 0, 0]))[0] return self.step(np.array([0, 0, 0, 0]))[0]
else: else:
@@ -589,9 +598,18 @@ class BipedalWalker(gym.Env, EzPickle):
done = True done = True
if pos[0] > (TERRAIN_LENGTH - TERRAIN_GRASS) * TERRAIN_STEP: if pos[0] > (TERRAIN_LENGTH - TERRAIN_GRASS) * TERRAIN_STEP:
done = True done = True
self.renderer.render_step()
return np.array(state, dtype=np.float32), reward, done, {} return np.array(state, dtype=np.float32), reward, done, {}
def render(self, mode: str = "human"): def render(self, mode: str = "human"):
if self.render_mode is not None:
return self.renderer.get_renders()
else:
return self._render(mode)
def _render(self, mode: str = "human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
from pygame import gfxdraw from pygame import gfxdraw
@@ -600,7 +618,7 @@ class BipedalWalker(gym.Env, EzPickle):
"pygame is not installed, run `pip install gym[box2d]`" "pygame is not installed, run `pip install gym[box2d]`"
) )
if self.screen is None: if self.screen is None and mode == "human":
pygame.init() pygame.init()
pygame.display.init() pygame.display.init()
self.screen = pygame.display.set_mode((VIEWPORT_W, VIEWPORT_H)) self.screen = pygame.display.set_mode((VIEWPORT_W, VIEWPORT_H))
@@ -653,18 +671,19 @@ class BipedalWalker(gym.Env, EzPickle):
self.lidar_render = (self.lidar_render + 1) % 100 self.lidar_render = (self.lidar_render + 1) % 100
i = self.lidar_render i = self.lidar_render
if i < 2 * len(self.lidar): if i < 2 * len(self.lidar):
l = ( single_lidar = (
self.lidar[i] self.lidar[i]
if i < len(self.lidar) if i < len(self.lidar)
else self.lidar[len(self.lidar) - i - 1] else self.lidar[len(self.lidar) - i - 1]
) )
pygame.draw.line( if hasattr(single_lidar, "p1") and hasattr(single_lidar, "p2"):
self.surf, pygame.draw.line(
color=(255, 0, 0), self.surf,
start_pos=(l.p1[0] * SCALE, l.p1[1] * SCALE), color=(255, 0, 0),
end_pos=(l.p2[0] * SCALE, l.p2[1] * SCALE), start_pos=(single_lidar.p1[0] * SCALE, single_lidar.p1[1] * SCALE),
width=1, end_pos=(single_lidar.p2[0] * SCALE, single_lidar.p2[1] * SCALE),
) width=1,
)
for obj in self.drawlist: for obj in self.drawlist:
for f in obj.fixtures: for f in obj.fixtures:
@@ -717,18 +736,16 @@ class BipedalWalker(gym.Env, EzPickle):
) )
self.surf = pygame.transform.flip(self.surf, False, True) self.surf = pygame.transform.flip(self.surf, False, True)
self.screen.blit(self.surf, (-self.scroll * SCALE, 0))
if mode == "human": if mode == "human":
self.screen.blit(self.surf, (-self.scroll * SCALE, 0))
pygame.event.pump() pygame.event.pump()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
pygame.display.flip() pygame.display.flip()
elif mode in {"rgb_array", "single_rgb_array"}:
if mode == "rgb_array":
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.surf)), axes=(1, 0, 2)
) )
else:
return self.isopen
def close(self): def close(self):
if self.screen is not None: if self.screen is not None:
@@ -829,6 +846,5 @@ if __name__ == "__main__":
a[3] = knee_todo[1] a[3] = knee_todo[1]
a = np.clip(0.5 * a, -1.0, 1.0) a = np.clip(0.5 * a, -1.0, 1.0)
env.render()
if done: if done:
break break

View File

@@ -10,6 +10,7 @@ from gym import spaces
from gym.envs.box2d.car_dynamics import Car from gym.envs.box2d.car_dynamics import Car
from gym.error import DependencyNotInstalled, InvalidAction from gym.error import DependencyNotInstalled, InvalidAction
from gym.utils import EzPickle from gym.utils import EzPickle
from gym.utils.renderer import Renderer
try: try:
import Box2D import Box2D
@@ -151,12 +152,19 @@ class CarRacing(gym.Env, EzPickle):
""" """
metadata = { metadata = {
"render_modes": ["human", "rgb_array", "state_pixels"], "render_modes": [
"human",
"rgb_array",
"state_pixels",
"single_rgb_array",
"single_state_pixels",
],
"render_fps": FPS, "render_fps": FPS,
} }
def __init__( def __init__(
self, self,
render_mode: Optional[str] = None,
verbose: bool = False, verbose: bool = False,
lap_complete_percent: float = 0.95, lap_complete_percent: float = 0.95,
domain_randomize: bool = False, domain_randomize: bool = False,
@@ -170,6 +178,7 @@ class CarRacing(gym.Env, EzPickle):
self.contactListener_keepref = FrictionDetector(self, lap_complete_percent) self.contactListener_keepref = FrictionDetector(self, lap_complete_percent)
self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref) self.world = Box2D.b2World((0, 0), contactListener=self.contactListener_keepref)
self.screen = None self.screen = None
self.surf = None
self.clock = None self.clock = None
self.isopen = True self.isopen = True
self.invisible_state_window = None self.invisible_state_window = None
@@ -199,6 +208,10 @@ class CarRacing(gym.Env, EzPickle):
low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8 low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8
) )
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
def _destroy(self): def _destroy(self):
if not self.road: if not self.road:
return return
@@ -441,6 +454,7 @@ class CarRacing(gym.Env, EzPickle):
) )
self.car = Car(self.world, *self.track[0][1:4]) self.car = Car(self.world, *self.track[0][1:4])
self.renderer.reset()
if not return_info: if not return_info:
return self.step(None)[0] return self.step(None)[0]
else: else:
@@ -466,7 +480,7 @@ class CarRacing(gym.Env, EzPickle):
self.world.Step(1.0 / FPS, 6 * 30, 2 * 30) self.world.Step(1.0 / FPS, 6 * 30, 2 * 30)
self.t += 1.0 / FPS self.t += 1.0 / FPS
self.state = self.render("state_pixels") self.state = self._render("single_state_pixels")
step_reward = 0 step_reward = 0
done = False done = False
@@ -484,9 +498,17 @@ class CarRacing(gym.Env, EzPickle):
done = True done = True
step_reward = -100 step_reward = -100
self.renderer.render_step()
return self.state, step_reward, done, {} return self.state, step_reward, done, {}
def render(self, mode: str = "human"): def render(self, mode: str = "human"):
if self.render_mode is not None:
return self.renderer.get_renders()
else:
return self._render(mode)
def _render(self, mode: str = "human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
except ImportError: except ImportError:
@@ -496,7 +518,6 @@ class CarRacing(gym.Env, EzPickle):
pygame.font.init() pygame.font.init()
assert mode in ["human", "state_pixels", "rgb_array"]
if self.screen is None and mode == "human": if self.screen is None and mode == "human":
pygame.init() pygame.init()
pygame.display.init() pygame.display.init()
@@ -519,7 +540,13 @@ class CarRacing(gym.Env, EzPickle):
trans = (WINDOW_W / 2 + trans[0], WINDOW_H / 4 + trans[1]) trans = (WINDOW_W / 2 + trans[0], WINDOW_H / 4 + trans[1])
self._render_road(zoom, trans, angle) self._render_road(zoom, trans, angle)
self.car.draw(self.surf, zoom, trans, angle, mode != "state_pixels") self.car.draw(
self.surf,
zoom,
trans,
angle,
mode not in ["state_pixels", "single_state_pixels"],
)
self.surf = pygame.transform.flip(self.surf, False, True) self.surf = pygame.transform.flip(self.surf, False, True)
@@ -539,9 +566,9 @@ class CarRacing(gym.Env, EzPickle):
self.screen.blit(self.surf, (0, 0)) self.screen.blit(self.surf, (0, 0))
pygame.display.flip() pygame.display.flip()
if mode == "rgb_array": if mode in {"rgb_array", "single_rgb_array"}:
return self._create_image_array(self.surf, (VIDEO_W, VIDEO_H)) return self._create_image_array(self.surf, (VIDEO_W, VIDEO_H))
elif mode == "state_pixels": elif mode in {"state_pixels", "single_state_pixels"}:
return self._create_image_array(self.surf, (STATE_W, STATE_H)) return self._create_image_array(self.surf, (STATE_W, STATE_H))
else: else:
return self.isopen return self.isopen

View File

@@ -10,6 +10,7 @@ import gym
from gym import error, spaces from gym import error, spaces
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils import EzPickle, colorize from gym.utils import EzPickle, colorize
from gym.utils.renderer import Renderer
try: try:
import Box2D import Box2D
@@ -171,10 +172,14 @@ class LunarLander(gym.Env, EzPickle):
Created by Oleg Klimov Created by Oleg Klimov
""" """
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": FPS} metadata = {
"render_modes": ["human", "rgb_array", "single_rgb_array"],
"render_fps": FPS,
}
def __init__( def __init__(
self, self,
render_mode: Optional[str] = None,
continuous: bool = False, continuous: bool = False,
gravity: float = -10.0, gravity: float = -10.0,
enable_wind: bool = False, enable_wind: bool = False,
@@ -267,6 +272,10 @@ class LunarLander(gym.Env, EzPickle):
# Nop, fire left engine, main engine, right engine # Nop, fire left engine, main engine, right engine
self.action_space = spaces.Discrete(4) self.action_space = spaces.Discrete(4)
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
def _destroy(self): def _destroy(self):
if not self.moon: if not self.moon:
return return
@@ -390,6 +399,7 @@ class LunarLander(gym.Env, EzPickle):
self.drawlist = [self.lander] + self.legs self.drawlist = [self.lander] + self.legs
self.renderer.reset()
if not return_info: if not return_info:
return self.step(np.array([0, 0]) if self.continuous else 0)[0] return self.step(np.array([0, 0]) if self.continuous else 0)[0]
else: else:
@@ -567,9 +577,18 @@ class LunarLander(gym.Env, EzPickle):
if not self.lander.awake: if not self.lander.awake:
done = True done = True
reward = +100 reward = +100
self.renderer.render_step()
return np.array(state, dtype=np.float32), reward, done, {} return np.array(state, dtype=np.float32), reward, done, {}
def render(self, mode="human"): def render(self, mode="human"):
if self.render_mode is not None:
return self.renderer.get_renders()
else:
return self._render(mode)
def _render(self, mode="human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
from pygame import gfxdraw from pygame import gfxdraw
@@ -578,14 +597,14 @@ class LunarLander(gym.Env, EzPickle):
"pygame is not installed, run `pip install gym[box2d]`" "pygame is not installed, run `pip install gym[box2d]`"
) )
if self.screen is None: if self.screen is None and mode == "human":
pygame.init() pygame.init()
pygame.display.init() pygame.display.init()
self.screen = pygame.display.set_mode((VIEWPORT_W, VIEWPORT_H)) self.screen = pygame.display.set_mode((VIEWPORT_W, VIEWPORT_H))
if self.clock is None: if self.clock is None:
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
self.surf = pygame.Surface(self.screen.get_size()) self.surf = pygame.Surface((VIEWPORT_W, VIEWPORT_H))
pygame.transform.scale(self.surf, (SCALE, SCALE)) pygame.transform.scale(self.surf, (SCALE, SCALE))
pygame.draw.rect(self.surf, (255, 255, 255), self.surf.get_rect()) pygame.draw.rect(self.surf, (255, 255, 255), self.surf.get_rect())
@@ -664,19 +683,16 @@ class LunarLander(gym.Env, EzPickle):
) )
self.surf = pygame.transform.flip(self.surf, False, True) self.surf = pygame.transform.flip(self.surf, False, True)
self.screen.blit(self.surf, (0, 0))
if mode == "human": if mode == "human":
self.screen.blit(self.surf, (0, 0))
pygame.event.pump() pygame.event.pump()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
pygame.display.flip() pygame.display.flip()
elif mode in {"rgb_array", "single_rgb_array"}:
if mode == "rgb_array":
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.surf)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.surf)), axes=(1, 0, 2)
) )
else:
return self.isopen
def close(self): def close(self):
if self.screen is not None: if self.screen is not None:

View File

@@ -20,6 +20,7 @@ __author__ = "Christoph Dann <cdann@cdann.de>"
# SOURCE: # SOURCE:
# https://github.com/rlpy/rlpy/blob/master/rlpy/Domains/Acrobot.py # https://github.com/rlpy/rlpy/blob/master/rlpy/Domains/Acrobot.py
from gym.utils.renderer import Renderer
class AcrobotEnv(core.Env): class AcrobotEnv(core.Env):
@@ -134,7 +135,10 @@ class AcrobotEnv(core.Env):
- Sutton, R. S., Barto, A. G. (2018 ). Reinforcement Learning: An Introduction. The MIT Press. - Sutton, R. S., Barto, A. G. (2018 ). Reinforcement Learning: An Introduction. The MIT Press.
""" """
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 15} metadata = {
"render_modes": ["human", "rgb_array", "single_rgb_array"],
"render_fps": 15,
}
dt = 0.2 dt = 0.2
@@ -161,7 +165,10 @@ class AcrobotEnv(core.Env):
domain_fig = None domain_fig = None
actions_num = 3 actions_num = 3
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
self.screen = None self.screen = None
self.clock = None self.clock = None
self.isopen = True self.isopen = True
@@ -184,6 +191,9 @@ class AcrobotEnv(core.Env):
self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)).astype( self.state = self.np_random.uniform(low=-0.1, high=0.1, size=(4,)).astype(
np.float32 np.float32
) )
self.renderer.reset()
self.renderer.render_step()
if not return_info: if not return_info:
return self._get_ob() return self._get_ob()
else: else:
@@ -213,7 +223,9 @@ class AcrobotEnv(core.Env):
self.state = ns self.state = ns
terminal = self._terminal() terminal = self._terminal()
reward = -1.0 if not terminal else 0.0 reward = -1.0 if not terminal else 0.0
return (self._get_ob(), reward, terminal, {})
self.renderer.render_step()
return self._get_ob(), reward, terminal, {}
def _get_ob(self): def _get_ob(self):
s = self.state s = self.state
@@ -267,9 +279,16 @@ class AcrobotEnv(core.Env):
a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1**2 * sin(theta2) - phi2 a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1**2 * sin(theta2) - phi2
) / (m2 * lc2**2 + I2 - d2**2 / d1) ) / (m2 * lc2**2 + I2 - d2**2 / d1)
ddtheta1 = -(d2 * ddtheta2 + phi1) / d1 ddtheta1 = -(d2 * ddtheta2 + phi1) / d1
return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0) return dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0
def render(self, mode="human"): def render(self, mode="human"):
if self.render_mode is not None:
return self.renderer.get_renders()
else:
return self._render(mode)
def _render(self, mode="human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
from pygame import gfxdraw from pygame import gfxdraw
@@ -280,13 +299,18 @@ class AcrobotEnv(core.Env):
if self.screen is None: if self.screen is None:
pygame.init() pygame.init()
pygame.display.init() if mode == "human":
self.screen = pygame.display.set_mode((self.SCREEN_DIM, self.SCREEN_DIM)) pygame.display.init()
self.screen = pygame.display.set_mode(
(self.SCREEN_DIM, self.SCREEN_DIM)
)
else: # mode in {"rgb_array", "single_rgb_array"}
self.screen = pygame.Surface((self.SCREEN_DIM, self.SCREEN_DIM))
if self.clock is None: if self.clock is None:
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
self.surf = pygame.Surface((self.SCREEN_DIM, self.SCREEN_DIM)) surf = pygame.Surface((self.SCREEN_DIM, self.SCREEN_DIM))
self.surf.fill((255, 255, 255)) surf.fill((255, 255, 255))
s = self.state s = self.state
bound = self.LINK_LENGTH_1 + self.LINK_LENGTH_2 + 0.2 # 2.2 for default bound = self.LINK_LENGTH_1 + self.LINK_LENGTH_2 + 0.2 # 2.2 for default
@@ -311,7 +335,7 @@ class AcrobotEnv(core.Env):
link_lengths = [self.LINK_LENGTH_1 * scale, self.LINK_LENGTH_2 * scale] link_lengths = [self.LINK_LENGTH_1 * scale, self.LINK_LENGTH_2 * scale]
pygame.draw.line( pygame.draw.line(
self.surf, surf,
start_pos=(-2.2 * scale + offset, 1 * scale + offset), start_pos=(-2.2 * scale + offset, 1 * scale + offset),
end_pos=(2.2 * scale + offset, 1 * scale + offset), end_pos=(2.2 * scale + offset, 1 * scale + offset),
color=(0, 0, 0), color=(0, 0, 0),
@@ -327,35 +351,33 @@ class AcrobotEnv(core.Env):
coord = pygame.math.Vector2(coord).rotate_rad(th) coord = pygame.math.Vector2(coord).rotate_rad(th)
coord = (coord[0] + x, coord[1] + y) coord = (coord[0] + x, coord[1] + y)
transformed_coords.append(coord) transformed_coords.append(coord)
gfxdraw.aapolygon(self.surf, transformed_coords, (0, 204, 204)) gfxdraw.aapolygon(surf, transformed_coords, (0, 204, 204))
gfxdraw.filled_polygon(self.surf, transformed_coords, (0, 204, 204)) gfxdraw.filled_polygon(surf, transformed_coords, (0, 204, 204))
gfxdraw.aacircle(self.surf, int(x), int(y), int(0.1 * scale), (204, 204, 0)) gfxdraw.aacircle(surf, int(x), int(y), int(0.1 * scale), (204, 204, 0))
gfxdraw.filled_circle( gfxdraw.filled_circle(surf, int(x), int(y), int(0.1 * scale), (204, 204, 0))
self.surf, int(x), int(y), int(0.1 * scale), (204, 204, 0)
) surf = pygame.transform.flip(surf, False, True)
self.screen.blit(surf, (0, 0))
self.surf = pygame.transform.flip(self.surf, False, True)
self.screen.blit(self.surf, (0, 0))
if mode == "human": if mode == "human":
pygame.event.pump() pygame.event.pump()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
pygame.display.flip() pygame.display.flip()
if mode == "rgb_array": elif mode in {"rgb_array", "single_rgb_array"}:
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
) )
else:
return self.isopen
def close(self):
if self.screen is not None:
import pygame
pygame.display.quit() def close(self):
pygame.quit() if self.screen is not None:
self.isopen = False import pygame
pygame.display.quit()
pygame.quit()
self.isopen = False
def wrap(x, m, M): def wrap(x, m, M):

View File

@@ -11,6 +11,7 @@ import numpy as np
import gym import gym
from gym import logger, spaces from gym import logger, spaces
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer
class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]): class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
@@ -79,9 +80,12 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
No additional arguments are currently supported. No additional arguments are currently supported.
""" """
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 50} metadata = {
"render_modes": ["human", "rgb_array", "single_rgb_array"],
"render_fps": 50,
}
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
self.gravity = 9.8 self.gravity = 9.8
self.masscart = 1.0 self.masscart = 1.0
self.masspole = 0.1 self.masspole = 0.1
@@ -111,6 +115,12 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
self.action_space = spaces.Discrete(2) self.action_space = spaces.Discrete(2)
self.observation_space = spaces.Box(-high, high, dtype=np.float32) self.observation_space = spaces.Box(-high, high, dtype=np.float32)
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
self.screen_width = 600
self.screen_height = 400
self.screen = None self.screen = None
self.clock = None self.clock = None
self.isopen = True self.isopen = True
@@ -174,6 +184,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
self.steps_beyond_done += 1 self.steps_beyond_done += 1
reward = 0.0 reward = 0.0
self.renderer.render_step()
return np.array(self.state, dtype=np.float32), reward, done, {} return np.array(self.state, dtype=np.float32), reward, done, {}
def reset( def reset(
@@ -186,12 +197,21 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
super().reset(seed=seed) super().reset(seed=seed)
self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,)) self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4,))
self.steps_beyond_done = None self.steps_beyond_done = None
self.renderer.reset()
self.renderer.render_step()
if not return_info: if not return_info:
return np.array(self.state, dtype=np.float32) return np.array(self.state, dtype=np.float32)
else: else:
return np.array(self.state, dtype=np.float32), {} return np.array(self.state, dtype=np.float32), {}
def render(self, mode="human"): def render(self, mode="human"):
if self.render_mode is not None:
return self.renderer.get_renders()
else:
return self._render(mode)
def _render(self, mode="human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
from pygame import gfxdraw from pygame import gfxdraw
@@ -200,11 +220,20 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
"pygame is not installed, run `pip install gym[classic_control]`" "pygame is not installed, run `pip install gym[classic_control]`"
) )
screen_width = 600 if self.screen is None:
screen_height = 400 pygame.init()
if mode == "human":
pygame.display.init()
self.screen = pygame.display.set_mode(
(self.screen_width, self.screen_height)
)
else: # mode in {"rgb_array", "single_rgb_array"}
self.screen = pygame.Surface((self.screen_width, self.screen_height))
if self.clock is None:
self.clock = pygame.time.Clock()
world_width = self.x_threshold * 2 world_width = self.x_threshold * 2
scale = screen_width / world_width scale = self.screen_width / world_width
polewidth = 10.0 polewidth = 10.0
polelen = scale * (2 * self.length) polelen = scale * (2 * self.length)
cartwidth = 50.0 cartwidth = 50.0
@@ -215,19 +244,12 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
x = self.state x = self.state
if self.screen is None: self.surf = pygame.Surface((self.screen_width, self.screen_height))
pygame.init()
pygame.display.init()
self.screen = pygame.display.set_mode((screen_width, screen_height))
if self.clock is None:
self.clock = pygame.time.Clock()
self.surf = pygame.Surface((screen_width, screen_height))
self.surf.fill((255, 255, 255)) self.surf.fill((255, 255, 255))
l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2 l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
axleoffset = cartheight / 4.0 axleoffset = cartheight / 4.0
cartx = x[0] * scale + screen_width / 2.0 # MIDDLE OF CART cartx = x[0] * scale + self.screen_width / 2.0 # MIDDLE OF CART
carty = 100 # TOP OF CART carty = 100 # TOP OF CART
cart_coords = [(l, b), (l, t), (r, t), (r, b)] cart_coords = [(l, b), (l, t), (r, t), (r, b)]
cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords] cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
@@ -264,7 +286,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
(129, 132, 203), (129, 132, 203),
) )
gfxdraw.hline(self.surf, 0, screen_width, carty, (0, 0, 0)) gfxdraw.hline(self.surf, 0, self.screen_width, carty, (0, 0, 0))
self.surf = pygame.transform.flip(self.surf, False, True) self.surf = pygame.transform.flip(self.surf, False, True)
self.screen.blit(self.surf, (0, 0)) self.screen.blit(self.surf, (0, 0))
@@ -273,12 +295,10 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
pygame.display.flip() pygame.display.flip()
if mode == "rgb_array": elif mode in {"rgb_array", "single_rgb_array"}:
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
) )
else:
return self.isopen
def close(self): def close(self):
if self.screen is not None: if self.screen is not None:

View File

@@ -21,6 +21,7 @@ import numpy as np
import gym import gym
from gym import spaces from gym import spaces
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer
class Continuous_MountainCarEnv(gym.Env): class Continuous_MountainCarEnv(gym.Env):
@@ -99,9 +100,12 @@ class Continuous_MountainCarEnv(gym.Env):
* v0: Initial versions release (1.0.0) * v0: Initial versions release (1.0.0)
""" """
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30} metadata = {
"render_modes": ["human", "rgb_array", "single_rgb_array"],
"render_fps": 30,
}
def __init__(self, goal_velocity=0): def __init__(self, render_mode: Optional[str] = None, goal_velocity=0):
self.min_action = -1.0 self.min_action = -1.0
self.max_action = 1.0 self.max_action = 1.0
self.min_position = -1.2 self.min_position = -1.2
@@ -120,6 +124,12 @@ class Continuous_MountainCarEnv(gym.Env):
[self.max_position, self.max_speed], dtype=np.float32 [self.max_position, self.max_speed], dtype=np.float32
) )
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
self.screen_width = 600
self.screen_height = 400
self.screen = None self.screen = None
self.clock = None self.clock = None
self.isopen = True self.isopen = True
@@ -159,6 +169,8 @@ class Continuous_MountainCarEnv(gym.Env):
reward -= math.pow(action[0], 2) * 0.1 reward -= math.pow(action[0], 2) * 0.1
self.state = np.array([position, velocity], dtype=np.float32) self.state = np.array([position, velocity], dtype=np.float32)
self.renderer.render_step()
return self.state, reward, done, {} return self.state, reward, done, {}
def reset( def reset(
@@ -170,6 +182,8 @@ class Continuous_MountainCarEnv(gym.Env):
): ):
super().reset(seed=seed) super().reset(seed=seed)
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0]) self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
self.renderer.reset()
self.renderer.render_step()
if not return_info: if not return_info:
return np.array(self.state, dtype=np.float32) return np.array(self.state, dtype=np.float32)
else: else:
@@ -179,6 +193,14 @@ class Continuous_MountainCarEnv(gym.Env):
return np.sin(3 * xs) * 0.45 + 0.55 return np.sin(3 * xs) * 0.45 + 0.55
def render(self, mode="human"): def render(self, mode="human"):
if self.render_mode is not None:
return self.renderer.get_renders()
else:
return self._render(mode)
def _render(self, mode="human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
from pygame import gfxdraw from pygame import gfxdraw
@@ -187,21 +209,24 @@ class Continuous_MountainCarEnv(gym.Env):
"pygame is not installed, run `pip install gym[classic_control]`" "pygame is not installed, run `pip install gym[classic_control]`"
) )
screen_width = 600
screen_height = 400
world_width = self.max_position - self.min_position
scale = screen_width / world_width
carwidth = 40
carheight = 20
if self.screen is None: if self.screen is None:
pygame.init() pygame.init()
pygame.display.init() if mode == "human":
self.screen = pygame.display.set_mode((screen_width, screen_height)) pygame.display.init()
self.screen = pygame.display.set_mode(
(self.screen_width, self.screen_height)
)
else: # mode in {"rgb_array", "single_rgb_array"}
self.screen = pygame.Surface((self.screen_width, self.screen_height))
if self.clock is None: if self.clock is None:
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
self.surf = pygame.Surface((screen_width, screen_height)) world_width = self.max_position - self.min_position
scale = self.screen_width / world_width
carwidth = 40
carheight = 20
self.surf = pygame.Surface((self.screen_width, self.screen_height))
self.surf.fill((255, 255, 255)) self.surf.fill((255, 255, 255))
pos = self.state[0] pos = self.state[0]
@@ -265,12 +290,10 @@ class Continuous_MountainCarEnv(gym.Env):
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
pygame.display.flip() pygame.display.flip()
if mode == "rgb_array": elif mode in {"rgb_array", "single_rgb_array"}:
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
) )
else:
return self.isopen
def close(self): def close(self):
if self.screen is not None: if self.screen is not None:

View File

@@ -10,6 +10,7 @@ import numpy as np
import gym import gym
from gym import spaces from gym import spaces
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer
class MountainCarEnv(gym.Env): class MountainCarEnv(gym.Env):
@@ -94,9 +95,12 @@ class MountainCarEnv(gym.Env):
* v0: Initial versions release (1.0.0) * v0: Initial versions release (1.0.0)
""" """
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30} metadata = {
"render_modes": ["human", "rgb_array", "single_rgb_array"],
"render_fps": 30,
}
def __init__(self, goal_velocity=0): def __init__(self, render_mode: Optional[str] = None, goal_velocity=0):
self.min_position = -1.2 self.min_position = -1.2
self.max_position = 0.6 self.max_position = 0.6
self.max_speed = 0.07 self.max_speed = 0.07
@@ -109,6 +113,12 @@ class MountainCarEnv(gym.Env):
self.low = np.array([self.min_position, -self.max_speed], dtype=np.float32) self.low = np.array([self.min_position, -self.max_speed], dtype=np.float32)
self.high = np.array([self.max_position, self.max_speed], dtype=np.float32) self.high = np.array([self.max_position, self.max_speed], dtype=np.float32)
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
self.screen_width = 600
self.screen_height = 400
self.screen = None self.screen = None
self.clock = None self.clock = None
self.isopen = True self.isopen = True
@@ -133,6 +143,8 @@ class MountainCarEnv(gym.Env):
reward = -1.0 reward = -1.0
self.state = (position, velocity) self.state = (position, velocity)
self.renderer.render_step()
return np.array(self.state, dtype=np.float32), reward, done, {} return np.array(self.state, dtype=np.float32), reward, done, {}
def reset( def reset(
@@ -144,6 +156,8 @@ class MountainCarEnv(gym.Env):
): ):
super().reset(seed=seed) super().reset(seed=seed)
self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0]) self.state = np.array([self.np_random.uniform(low=-0.6, high=-0.4), 0])
self.renderer.reset()
self.renderer.render_step()
if not return_info: if not return_info:
return np.array(self.state, dtype=np.float32) return np.array(self.state, dtype=np.float32)
else: else:
@@ -153,6 +167,13 @@ class MountainCarEnv(gym.Env):
return np.sin(3 * xs) * 0.45 + 0.55 return np.sin(3 * xs) * 0.45 + 0.55
def render(self, mode="human"): def render(self, mode="human"):
if self.render_mode is not None:
return self.renderer.get_renders()
else:
return self._render(mode)
def _render(self, mode="human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
from pygame import gfxdraw from pygame import gfxdraw
@@ -161,21 +182,24 @@ class MountainCarEnv(gym.Env):
"pygame is not installed, run `pip install gym[classic_control]`" "pygame is not installed, run `pip install gym[classic_control]`"
) )
screen_width = 600
screen_height = 400
world_width = self.max_position - self.min_position
scale = screen_width / world_width
carwidth = 40
carheight = 20
if self.screen is None: if self.screen is None:
pygame.init() pygame.init()
pygame.display.init() if mode == "human":
self.screen = pygame.display.set_mode((screen_width, screen_height)) pygame.display.init()
self.screen = pygame.display.set_mode(
(self.screen_width, self.screen_height)
)
else: # mode in {"rgb_array", "single_rgb_array"}
self.screen = pygame.Surface((self.screen_width, self.screen_height))
if self.clock is None: if self.clock is None:
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
self.surf = pygame.Surface((screen_width, screen_height)) world_width = self.max_position - self.min_position
scale = self.screen_width / world_width
carwidth = 40
carheight = 20
self.surf = pygame.Surface((self.screen_width, self.screen_height))
self.surf.fill((255, 255, 255)) self.surf.fill((255, 255, 255))
pos = self.state[0] pos = self.state[0]
@@ -239,12 +263,10 @@ class MountainCarEnv(gym.Env):
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
pygame.display.flip() pygame.display.flip()
if mode == "rgb_array": elif mode in {"rgb_array", "single_rgb_array"}:
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
) )
else:
return self.isopen
def get_keys_to_action(self): def get_keys_to_action(self):
# Control with left and right arrow keys. # Control with left and right arrow keys.

View File

@@ -8,6 +8,7 @@ import numpy as np
import gym import gym
from gym import spaces from gym import spaces
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer
class PendulumEnv(gym.Env): class PendulumEnv(gym.Env):
@@ -83,21 +84,28 @@ class PendulumEnv(gym.Env):
""" """
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30} metadata = {
"render_modes": ["human", "rgb_array", "single_rgb_array"],
"render_fps": 30,
}
def __init__(self, g=10.0): def __init__(self, render_mode: Optional[str] = None, g=10.0):
self.max_speed = 8 self.max_speed = 8
self.max_torque = 2.0 self.max_torque = 2.0
self.dt = 0.05 self.dt = 0.05
self.g = g self.g = g
self.m = 1.0 self.m = 1.0
self.l = 1.0 self.l = 1.0
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
self.screen_dim = 500
self.screen = None self.screen = None
self.clock = None self.clock = None
self.isopen = True self.isopen = True
self.screen_dim = 500
high = np.array([1.0, 1.0, self.max_speed], dtype=np.float32) high = np.array([1.0, 1.0, self.max_speed], dtype=np.float32)
# This will throw a warning in tests/envs/test_envs in utils/env_checker.py as the space is not symmetric # This will throw a warning in tests/envs/test_envs in utils/env_checker.py as the space is not symmetric
# or normalised as max_torque == 2 by default. Ignoring the issue here as the default settings are too old # or normalised as max_torque == 2 by default. Ignoring the issue here as the default settings are too old
@@ -124,6 +132,7 @@ class PendulumEnv(gym.Env):
newth = th + newthdot * dt newth = th + newthdot * dt
self.state = np.array([newth, newthdot]) self.state = np.array([newth, newthdot])
self.renderer.render_step()
return self._get_obs(), -costs, False, {} return self._get_obs(), -costs, False, {}
def reset( def reset(
@@ -137,6 +146,9 @@ class PendulumEnv(gym.Env):
high = np.array([np.pi, 1]) high = np.array([np.pi, 1])
self.state = self.np_random.uniform(low=-high, high=high) self.state = self.np_random.uniform(low=-high, high=high)
self.last_u = None self.last_u = None
self.renderer.reset()
self.renderer.render_step()
if not return_info: if not return_info:
return self._get_obs() return self._get_obs()
else: else:
@@ -147,6 +159,13 @@ class PendulumEnv(gym.Env):
return np.array([np.cos(theta), np.sin(theta), thetadot], dtype=np.float32) return np.array([np.cos(theta), np.sin(theta), thetadot], dtype=np.float32)
def render(self, mode="human"): def render(self, mode="human"):
if self.render_mode is not None:
return self.renderer.get_renders()
else:
return self._render(mode)
def _render(self, mode="human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
from pygame import gfxdraw from pygame import gfxdraw
@@ -157,8 +176,13 @@ class PendulumEnv(gym.Env):
if self.screen is None: if self.screen is None:
pygame.init() pygame.init()
pygame.display.init() if mode == "human":
self.screen = pygame.display.set_mode((self.screen_dim, self.screen_dim)) pygame.display.init()
self.screen = pygame.display.set_mode(
(self.screen_dim, self.screen_dim)
)
else: # mode in {"rgb_array", "single_rgb_array"}
self.screen = pygame.Surface((self.screen_dim, self.screen_dim))
if self.clock is None: if self.clock is None:
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
@@ -200,7 +224,8 @@ class PendulumEnv(gym.Env):
img = pygame.image.load(fname) img = pygame.image.load(fname)
if self.last_u is not None: if self.last_u is not None:
scale_img = pygame.transform.smoothscale( scale_img = pygame.transform.smoothscale(
img, (scale * np.abs(self.last_u) / 2, scale * np.abs(self.last_u) / 2) img,
(scale * np.abs(self.last_u) / 2, scale * np.abs(self.last_u) / 2),
) )
is_flip = bool(self.last_u > 0) is_flip = bool(self.last_u > 0)
scale_img = pygame.transform.flip(scale_img, is_flip, True) scale_img = pygame.transform.flip(scale_img, is_flip, True)
@@ -223,12 +248,10 @@ class PendulumEnv(gym.Env):
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
pygame.display.flip() pygame.display.flip()
if mode == "rgb_array": else: # mode == "rgb_array":
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
) )
else:
return self.isopen
def close(self): def close(self):
if self.screen is not None: if self.screen is not None:

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -5,14 +7,19 @@ from gym.envs.mujoco import mujoco_env
class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
mujoco_env.MujocoEnv.__init__(self, "ant.xml", 5, mujoco_bindings="mujoco_py") mujoco_env.MujocoEnv.__init__(
self, "ant.xml", 5, render_mode=render_mode, mujoco_bindings="mujoco_py"
)
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
def step(self, a): def step(self, a):
xposbefore = self.get_body_com("torso")[0] xposbefore = self.get_body_com("torso")[0]
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
xposafter = self.get_body_com("torso")[0] xposafter = self.get_body_com("torso")[0]
self.renderer.render_step()
forward_reward = (xposafter - xposbefore) / self.dt forward_reward = (xposafter - xposbefore) / self.dt
ctrl_cost = 0.5 * np.square(a).sum() ctrl_cost = 0.5 * np.square(a).sum()
contact_cost = ( contact_cost = (

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -11,6 +13,7 @@ DEFAULT_CAMERA_CONFIG = {
class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle): class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__( def __init__(
self, self,
render_mode: Optional[str] = None,
xml_file="ant.xml", xml_file="ant.xml",
ctrl_cost_weight=0.5, ctrl_cost_weight=0.5,
contact_cost_weight=5e-4, contact_cost_weight=5e-4,
@@ -94,6 +97,8 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
rewards = forward_reward + healthy_reward rewards = forward_reward + healthy_reward
costs = ctrl_cost + contact_cost costs = ctrl_cost + contact_cost
self.renderer.render_step()
reward = rewards - costs reward = rewards - costs
done = self.done done = self.done
observation = self._get_obs() observation = self._get_obs()

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -164,6 +166,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__( def __init__(
self, self,
render_mode: Optional[str] = None,
xml_file="ant.xml", xml_file="ant.xml",
ctrl_cost_weight=0.5, ctrl_cost_weight=0.5,
use_contact_forces=False, use_contact_forces=False,
@@ -194,7 +197,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, xml_file, 5) mujoco_env.MujocoEnv.__init__(self, xml_file, 5, render_mode=render_mode)
@property @property
def healthy_reward(self): def healthy_reward(self):
@@ -268,6 +271,7 @@ class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
reward = rewards - costs reward = rewards - costs
self.renderer.render_step()
return observation, reward, done, info return observation, reward, done, info
def _get_obs(self): def _get_obs(self):

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -5,9 +7,13 @@ from gym.envs.mujoco import mujoco_env
class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle): class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "half_cheetah.xml", 5, mujoco_bindings="mujoco_py" self,
"half_cheetah.xml",
5,
render_mode=render_mode,
mujoco_bindings="mujoco_py",
) )
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
@@ -15,6 +21,9 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
xposbefore = self.sim.data.qpos[0] xposbefore = self.sim.data.qpos[0]
self.do_simulation(action, self.frame_skip) self.do_simulation(action, self.frame_skip)
xposafter = self.sim.data.qpos[0] xposafter = self.sim.data.qpos[0]
self.renderer.render_step()
ob = self._get_obs() ob = self._get_obs()
reward_ctrl = -0.1 * np.square(action).sum() reward_ctrl = -0.1 * np.square(action).sum()
reward_run = (xposafter - xposbefore) / self.dt reward_run = (xposafter - xposbefore) / self.dt

View File

@@ -1,4 +1,7 @@
__credits__ = ["Rushiv Arora"] __credits__ = ["Rushiv Arora"]
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -12,6 +15,7 @@ DEFAULT_CAMERA_CONFIG = {
class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle): class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__( def __init__(
self, self,
render_mode: Optional[str] = None,
xml_file="half_cheetah.xml", xml_file="half_cheetah.xml",
forward_reward_weight=1.0, forward_reward_weight=1.0,
ctrl_cost_weight=0.1, ctrl_cost_weight=0.1,
@@ -30,7 +34,9 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, xml_file, 5, mujoco_bindings="mujoco_py") mujoco_env.MujocoEnv.__init__(
self, xml_file, 5, render_mode=render_mode, mujoco_bindings="mujoco_py"
)
def control_cost(self, action): def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action)) control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
@@ -46,6 +52,8 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
forward_reward = self._forward_reward_weight * x_velocity forward_reward = self._forward_reward_weight * x_velocity
self.renderer.render_step()
observation = self._get_obs() observation = self._get_obs()
reward = forward_reward - ctrl_cost reward = forward_reward - ctrl_cost
done = False done = False

View File

@@ -1,4 +1,7 @@
__credits__ = ["Rushiv Arora"] __credits__ = ["Rushiv Arora"]
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -148,6 +151,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__( def __init__(
self, self,
render_mode: Optional[str] = None,
xml_file="half_cheetah.xml", xml_file="half_cheetah.xml",
forward_reward_weight=1.0, forward_reward_weight=1.0,
ctrl_cost_weight=0.1, ctrl_cost_weight=0.1,
@@ -166,7 +170,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, xml_file, 5) mujoco_env.MujocoEnv.__init__(self, xml_file, 5, render_mode=render_mode)
def control_cost(self, action): def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action)) control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
@@ -192,6 +196,7 @@ class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
"reward_ctrl": -ctrl_cost, "reward_ctrl": -ctrl_cost,
} }
self.renderer.render_step()
return observation, reward, done, info return observation, reward, done, info
def _get_obs(self): def _get_obs(self):

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -5,9 +7,9 @@ from gym.envs.mujoco import mujoco_env
class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle): class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "hopper.xml", 4, mujoco_bindings="mujoco_py" self, "hopper.xml", 4, render_mode=render_mode, mujoco_bindings="mujoco_py"
) )
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
@@ -15,6 +17,9 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
posbefore = self.sim.data.qpos[0] posbefore = self.sim.data.qpos[0]
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
posafter, height, ang = self.sim.data.qpos[0:3] posafter, height, ang = self.sim.data.qpos[0:3]
self.renderer.render_step()
alive_bonus = 1.0 alive_bonus = 1.0
reward = (posafter - posbefore) / self.dt reward = (posafter - posbefore) / self.dt
reward += alive_bonus reward += alive_bonus

View File

@@ -1,5 +1,7 @@
__credits__ = ["Rushiv Arora"] __credits__ = ["Rushiv Arora"]
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -16,6 +18,7 @@ DEFAULT_CAMERA_CONFIG = {
class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle): class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__( def __init__(
self, self,
render_mode: Optional[str] = None,
xml_file="hopper.xml", xml_file="hopper.xml",
forward_reward_weight=1.0, forward_reward_weight=1.0,
ctrl_cost_weight=1e-3, ctrl_cost_weight=1e-3,
@@ -46,7 +49,9 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, xml_file, 4, mujoco_bindings="mujoco_py") mujoco_env.MujocoEnv.__init__(
self, xml_file, 4, render_mode=render_mode, mujoco_bindings="mujoco_py"
)
@property @property
def healthy_reward(self): def healthy_reward(self):
@@ -105,6 +110,8 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
rewards = forward_reward + healthy_reward rewards = forward_reward + healthy_reward
costs = ctrl_cost costs = ctrl_cost
self.renderer.render_step()
observation = self._get_obs() observation = self._get_obs()
reward = rewards - costs reward = rewards - costs
done = self.done done = self.done

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -140,6 +142,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__( def __init__(
self, self,
render_mode: Optional[str] = None,
xml_file="hopper.xml", xml_file="hopper.xml",
forward_reward_weight=1.0, forward_reward_weight=1.0,
ctrl_cost_weight=1e-3, ctrl_cost_weight=1e-3,
@@ -170,7 +173,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, xml_file, 4) mujoco_env.MujocoEnv.__init__(self, xml_file, 4, render_mode=render_mode)
@property @property
def healthy_reward(self): def healthy_reward(self):
@@ -237,6 +240,7 @@ class HopperEnv(mujoco_env.MujocoEnv, utils.EzPickle):
"x_velocity": x_velocity, "x_velocity": x_velocity,
} }
self.renderer.render_step()
return observation, reward, done, info return observation, reward, done, info
def reset_model(self): def reset_model(self):

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -11,9 +13,13 @@ def mass_center(model, sim):
class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle): class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "humanoid.xml", 5, mujoco_bindings="mujoco_py" self,
"humanoid.xml",
5,
render_mode=render_mode,
mujoco_bindings="mujoco_py",
) )
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
@@ -34,6 +40,9 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
pos_before = mass_center(self.model, self.sim) pos_before = mass_center(self.model, self.sim)
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
pos_after = mass_center(self.model, self.sim) pos_after = mass_center(self.model, self.sim)
self.renderer.render_step()
alive_bonus = 5.0 alive_bonus = 5.0
data = self.sim.data data = self.sim.data
lin_vel_cost = 1.25 * (pos_after - pos_before) / self.dt lin_vel_cost = 1.25 * (pos_after - pos_before) / self.dt

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -20,6 +22,7 @@ def mass_center(model, sim):
class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle): class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__( def __init__(
self, self,
render_mode: Optional[str] = None,
xml_file="humanoid.xml", xml_file="humanoid.xml",
forward_reward_weight=1.25, forward_reward_weight=1.25,
ctrl_cost_weight=0.1, ctrl_cost_weight=0.1,
@@ -47,7 +50,9 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, xml_file, 5, mujoco_bindings="mujoco_py") mujoco_env.MujocoEnv.__init__(
self, xml_file, 5, render_mode=render_mode, mujoco_bindings="mujoco_py"
)
@property @property
def healthy_reward(self): def healthy_reward(self):
@@ -121,6 +126,8 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
rewards = forward_reward + healthy_reward rewards = forward_reward + healthy_reward
costs = ctrl_cost + contact_cost costs = ctrl_cost + contact_cost
self.renderer.render_step()
observation = self._get_obs() observation = self._get_obs()
reward = rewards - costs reward = rewards - costs
done = self.done done = self.done

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -204,6 +206,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__( def __init__(
self, self,
render_mode: Optional[str] = None,
xml_file="humanoid.xml", xml_file="humanoid.xml",
forward_reward_weight=1.25, forward_reward_weight=1.25,
ctrl_cost_weight=0.1, ctrl_cost_weight=0.1,
@@ -227,7 +230,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, xml_file, 5) mujoco_env.MujocoEnv.__init__(self, xml_file, 5, render_mode=render_mode)
@property @property
def healthy_reward(self): def healthy_reward(self):
@@ -306,6 +309,7 @@ class HumanoidEnv(mujoco_env.MujocoEnv, utils.EzPickle):
"forward_reward": forward_reward, "forward_reward": forward_reward,
} }
self.renderer.render_step()
return observation, reward, done, info return observation, reward, done, info
def reset_model(self): def reset_model(self):

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -5,9 +7,13 @@ from gym.envs.mujoco import mujoco_env
class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle): class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "humanoidstandup.xml", 5, mujoco_bindings="mujoco_py" self,
"humanoidstandup.xml",
5,
render_mode=render_mode,
mujoco_bindings="mujoco_py",
) )
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
@@ -35,6 +41,8 @@ class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
quad_impact_cost = min(quad_impact_cost, 10) quad_impact_cost = min(quad_impact_cost, 10)
reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1 reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1
self.renderer.render_step()
done = bool(False) done = bool(False)
return ( return (
self._get_obs(), self._get_obs(),

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -190,8 +192,10 @@ class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
""" """
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
mujoco_env.MujocoEnv.__init__(self, "humanoidstandup.xml", 5) mujoco_env.MujocoEnv.__init__(
self, "humanoidstandup.xml", 5, render_mode=render_mode
)
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
def _get_obs(self): def _get_obs(self):
@@ -218,6 +222,8 @@ class HumanoidStandupEnv(mujoco_env.MujocoEnv, utils.EzPickle):
quad_impact_cost = min(quad_impact_cost, 10) quad_impact_cost = min(quad_impact_cost, 10)
reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1 reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1
self.renderer.render_step()
done = bool(False) done = bool(False)
return ( return (
self._get_obs(), self._get_obs(),

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -5,14 +7,21 @@ from gym.envs.mujoco import mujoco_env
class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle): class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "inverted_double_pendulum.xml", 5, mujoco_bindings="mujoco_py" self,
"inverted_double_pendulum.xml",
5,
render_mode=render_mode,
mujoco_bindings="mujoco_py",
) )
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
def step(self, action): def step(self, action):
self.do_simulation(action, self.frame_skip) self.do_simulation(action, self.frame_skip)
self.renderer.render_step()
ob = self._get_obs() ob = self._get_obs()
x, _, y = self.sim.data.site_xpos[0] x, _, y = self.sim.data.site_xpos[0]
dist_penalty = 0.01 * x**2 + (y - 2) ** 2 dist_penalty = 0.01 * x**2 + (y - 2) ** 2

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -118,8 +120,10 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
""" """
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
mujoco_env.MujocoEnv.__init__(self, "inverted_double_pendulum.xml", 5) mujoco_env.MujocoEnv.__init__(
self, "inverted_double_pendulum.xml", 5, render_mode=render_mode
)
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
def step(self, action): def step(self, action):
@@ -132,6 +136,9 @@ class InvertedDoublePendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
alive_bonus = 10 alive_bonus = 10
r = alive_bonus - dist_penalty - vel_penalty r = alive_bonus - dist_penalty - vel_penalty
done = bool(y <= 1) done = bool(y <= 1)
self.renderer.render_step()
return ob, r, done, {} return ob, r, done, {}
def _get_obs(self): def _get_obs(self):

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -5,15 +7,22 @@ from gym.envs.mujoco import mujoco_env
class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle): class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "inverted_pendulum.xml", 2, mujoco_bindings="mujoco_py" self,
"inverted_pendulum.xml",
2,
render_mode=render_mode,
mujoco_bindings="mujoco_py",
) )
def step(self, a): def step(self, a):
reward = 1.0 reward = 1.0
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
self.renderer.render_step()
ob = self._get_obs() ob = self._get_obs()
notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= 0.2) notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= 0.2)
done = not notdone done = not notdone

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -90,9 +92,11 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
""" """
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
mujoco_env.MujocoEnv.__init__(self, "inverted_pendulum.xml", 2) mujoco_env.MujocoEnv.__init__(
self, "inverted_pendulum.xml", 2, render_mode=render_mode
)
def step(self, a): def step(self, a):
reward = 1.0 reward = 1.0
@@ -100,6 +104,9 @@ class InvertedPendulumEnv(mujoco_env.MujocoEnv, utils.EzPickle):
ob = self._get_obs() ob = self._get_obs()
notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= 0.2) notdone = np.isfinite(ob).all() and (np.abs(ob[1]) <= 0.2)
done = not notdone done = not notdone
self.renderer.render_step()
return ob, reward, done, {} return ob, reward, done, {}
def reset_model(self): def reset_model(self):

View File

@@ -6,6 +6,7 @@ import numpy as np
import gym import gym
from gym import error, logger, spaces from gym import error, logger, spaces
from gym.utils.renderer import Renderer
DEFAULT_SIZE = 480 DEFAULT_SIZE = 480
@@ -33,8 +34,13 @@ def convert_observation_to_space(observation):
class MujocoEnv(gym.Env): class MujocoEnv(gym.Env):
"""Superclass for all MuJoCo environments.""" """Superclass for all MuJoCo environments."""
def __init__(self, model_path, frame_skip, mujoco_bindings="mujoco"): def __init__(
self,
model_path,
frame_skip,
render_mode: Optional[str] = None,
mujoco_bindings="mujoco",
):
if model_path.startswith("/"): if model_path.startswith("/"):
fullpath = model_path fullpath = model_path
else: else:
@@ -87,12 +93,22 @@ class MujocoEnv(gym.Env):
self.viewer = None self.viewer = None
self.metadata = { self.metadata = {
"render_modes": ["human", "rgb_array", "depth_array"], "render_modes": [
"human",
"rgb_array",
"depth_array",
"single_rgb_array",
"single_depth_array",
],
"render_fps": int(np.round(1.0 / self.dt)), "render_fps": int(np.round(1.0 / self.dt)),
} }
self._set_action_space() self._set_action_space()
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
action = self.action_space.sample() action = self.action_space.sample()
observation, _reward, done, _info = self.step(action) observation, _reward, done, _info = self.step(action)
assert not done assert not done
@@ -142,6 +158,8 @@ class MujocoEnv(gym.Env):
self._mujoco_bindings.mj_resetData(self.model, self.data) self._mujoco_bindings.mj_resetData(self.model, self.data)
ob = self.reset_model() ob = self.reset_model()
self.renderer.reset()
self.renderer.render_step()
if not return_info: if not return_info:
return ob return ob
else: else:
@@ -195,7 +213,33 @@ class MujocoEnv(gym.Env):
camera_id=None, camera_id=None,
camera_name=None, camera_name=None,
): ):
if mode == "rgb_array" or mode == "depth_array": if self.render_mode is not None:
return self.renderer.get_renders()
else:
return self._render(
mode=mode,
width=width,
height=height,
camera_id=camera_id,
camera_name=camera_name,
)
def _render(
self,
mode="human",
width=DEFAULT_SIZE,
height=DEFAULT_SIZE,
camera_id=None,
camera_name=None,
):
assert mode in self.metadata["render_modes"]
if mode in {
"rgb_array",
"single_rgb_array",
"depth_array",
"single_depth_array",
}:
if camera_id is not None and camera_name is not None: if camera_id is not None and camera_name is not None:
raise ValueError( raise ValueError(
"Both `camera_id` and `camera_name` cannot be" "Both `camera_id` and `camera_name` cannot be"
@@ -219,11 +263,11 @@ class MujocoEnv(gym.Env):
self._get_viewer(mode).render(width, height, camera_id=camera_id) self._get_viewer(mode).render(width, height, camera_id=camera_id)
if mode == "rgb_array": if mode in {"rgb_array", "single_rgb_array"}:
data = self._get_viewer(mode).read_pixels(width, height, depth=False) data = self._get_viewer(mode).read_pixels(width, height, depth=False)
# original image is upside-down, so flip it # original image is upside-down, so flip it
return data[::-1, :, :] return data[::-1, :, :]
elif mode == "depth_array": elif mode in {"depth_array", "single_depth_array"}:
self._get_viewer(mode).render(width, height) self._get_viewer(mode).render(width, height)
# Extract depth part of the read_pixels() tuple # Extract depth part of the read_pixels() tuple
data = self._get_viewer(mode).read_pixels(width, height, depth=True)[1] data = self._get_viewer(mode).read_pixels(width, height, depth=True)[1]
@@ -249,7 +293,12 @@ class MujocoEnv(gym.Env):
from gym.envs.mujoco.mujoco_rendering import Viewer from gym.envs.mujoco.mujoco_rendering import Viewer
self.viewer = Viewer(self.model, self.data) self.viewer = Viewer(self.model, self.data)
elif mode == "rgb_array" or mode == "depth_array": elif mode in {
"rgb_array",
"depth_array",
"single_rgb_array",
"single_depth_array",
}:
if self._mujoco_bindings.__name__ == "mujoco_py": if self._mujoco_bindings.__name__ == "mujoco_py":
self.viewer = self._mujoco_bindings.MjRenderContextOffscreen( self.viewer = self._mujoco_bindings.MjRenderContextOffscreen(
self.sim, -1 self.sim, -1

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -5,10 +7,10 @@ from gym.envs.mujoco import mujoco_env
class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle): class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "pusher.xml", 5, mujoco_bindings="mujoco_py" self, "pusher.xml", 5, render_mode=render_mode, mujoco_bindings="mujoco_py"
) )
def step(self, a): def step(self, a):
@@ -21,6 +23,9 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
reward = reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near reward = reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
self.renderer.render_step()
ob = self._get_obs() ob = self._get_obs()
done = False done = False
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -132,9 +134,9 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
""" """
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
mujoco_env.MujocoEnv.__init__(self, "pusher.xml", 5) mujoco_env.MujocoEnv.__init__(self, "pusher.xml", 5, render_mode=render_mode)
def step(self, a): def step(self, a):
vec_1 = self.get_body_com("object") - self.get_body_com("tips_arm") vec_1 = self.get_body_com("object") - self.get_body_com("tips_arm")
@@ -148,6 +150,9 @@ class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
ob = self._get_obs() ob = self._get_obs()
done = False done = False
self.renderer.render_step()
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
def viewer_setup(self): def viewer_setup(self):

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -5,10 +7,10 @@ from gym.envs.mujoco import mujoco_env
class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle): class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "reacher.xml", 2, mujoco_bindings="mujoco_py" self, "reacher.xml", 2, render_mode=render_mode, mujoco_bindings="mujoco_py"
) )
def step(self, a): def step(self, a):
@@ -17,6 +19,9 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
reward_ctrl = -np.square(a).sum() reward_ctrl = -np.square(a).sum()
reward = reward_dist + reward_ctrl reward = reward_dist + reward_ctrl
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
self.renderer.render_step()
ob = self._get_obs() ob = self._get_obs()
done = False done = False
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -122,9 +124,9 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
""" """
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
mujoco_env.MujocoEnv.__init__(self, "reacher.xml", 2) mujoco_env.MujocoEnv.__init__(self, "reacher.xml", 2, render_mode=render_mode)
def step(self, a): def step(self, a):
vec = self.get_body_com("fingertip") - self.get_body_com("target") vec = self.get_body_com("fingertip") - self.get_body_com("target")
@@ -134,6 +136,9 @@ class ReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
ob = self._get_obs() ob = self._get_obs()
done = False done = False
self.renderer.render_step()
return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
def viewer_setup(self): def viewer_setup(self):

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -5,9 +7,9 @@ from gym.envs.mujoco import mujoco_env
class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "swimmer.xml", 4, mujoco_bindings="mujoco_py" self, "swimmer.xml", 4, render_mode=render_mode, mujoco_bindings="mujoco_py"
) )
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
@@ -16,6 +18,9 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
xposbefore = self.sim.data.qpos[0] xposbefore = self.sim.data.qpos[0]
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
xposafter = self.sim.data.qpos[0] xposafter = self.sim.data.qpos[0]
self.renderer.render_step()
reward_fwd = (xposafter - xposbefore) / self.dt reward_fwd = (xposafter - xposbefore) / self.dt
reward_ctrl = -ctrl_cost_coeff * np.square(a).sum() reward_ctrl = -ctrl_cost_coeff * np.square(a).sum()
reward = reward_fwd + reward_ctrl reward = reward_fwd + reward_ctrl

View File

@@ -1,5 +1,7 @@
__credits__ = ["Rushiv Arora"] __credits__ = ["Rushiv Arora"]
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -11,6 +13,7 @@ DEFAULT_CAMERA_CONFIG = {}
class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__( def __init__(
self, self,
render_mode: Optional[str] = None,
xml_file="swimmer.xml", xml_file="swimmer.xml",
forward_reward_weight=1.0, forward_reward_weight=1.0,
ctrl_cost_weight=1e-4, ctrl_cost_weight=1e-4,
@@ -28,7 +31,9 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, xml_file, 4, mujoco_bindings="mujoco_py") mujoco_env.MujocoEnv.__init__(
self, xml_file, 4, render_mode=render_mode, mujoco_bindings="mujoco_py"
)
def control_cost(self, action): def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action)) control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
@@ -39,11 +44,12 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
self.do_simulation(action, self.frame_skip) self.do_simulation(action, self.frame_skip)
xy_position_after = self.sim.data.qpos[0:2].copy() xy_position_after = self.sim.data.qpos[0:2].copy()
self.renderer.render_step()
xy_velocity = (xy_position_after - xy_position_before) / self.dt xy_velocity = (xy_position_after - xy_position_before) / self.dt
x_velocity, y_velocity = xy_velocity x_velocity, y_velocity = xy_velocity
forward_reward = self._forward_reward_weight * x_velocity forward_reward = self._forward_reward_weight * x_velocity
ctrl_cost = self.control_cost(action) ctrl_cost = self.control_cost(action)
observation = self._get_obs() observation = self._get_obs()

View File

@@ -1,5 +1,7 @@
__credits__ = ["Rushiv Arora"] __credits__ = ["Rushiv Arora"]
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -132,6 +134,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__( def __init__(
self, self,
render_mode: Optional[str] = None,
xml_file="swimmer.xml", xml_file="swimmer.xml",
forward_reward_weight=1.0, forward_reward_weight=1.0,
ctrl_cost_weight=1e-4, ctrl_cost_weight=1e-4,
@@ -149,7 +152,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, xml_file, 4) mujoco_env.MujocoEnv.__init__(self, xml_file, 4, render_mode=render_mode)
def control_cost(self, action): def control_cost(self, action):
control_cost = self._ctrl_cost_weight * np.sum(np.square(action)) control_cost = self._ctrl_cost_weight * np.sum(np.square(action))
@@ -181,6 +184,7 @@ class SwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
"forward_reward": forward_reward, "forward_reward": forward_reward,
} }
self.renderer.render_step()
return observation, reward, done, info return observation, reward, done, info
def _get_obs(self): def _get_obs(self):

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -5,9 +7,13 @@ from gym.envs.mujoco import mujoco_env
class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle): class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
mujoco_env.MujocoEnv.__init__( mujoco_env.MujocoEnv.__init__(
self, "walker2d.xml", 4, mujoco_bindings="mujoco_py" self,
"walker2d.xml",
4,
render_mode=render_mode,
mujoco_bindings="mujoco_py",
) )
utils.EzPickle.__init__(self) utils.EzPickle.__init__(self)
@@ -15,12 +21,16 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
posbefore = self.sim.data.qpos[0] posbefore = self.sim.data.qpos[0]
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
posafter, height, ang = self.sim.data.qpos[0:3] posafter, height, ang = self.sim.data.qpos[0:3]
self.renderer.render_step()
alive_bonus = 1.0 alive_bonus = 1.0
reward = (posafter - posbefore) / self.dt reward = (posafter - posbefore) / self.dt
reward += alive_bonus reward += alive_bonus
reward -= 1e-3 * np.square(a).sum() reward -= 1e-3 * np.square(a).sum()
done = not (height > 0.8 and height < 2.0 and ang > -1.0 and ang < 1.0) done = not (height > 0.8 and height < 2.0 and ang > -1.0 and ang < 1.0)
ob = self._get_obs() ob = self._get_obs()
return ob, reward, done, {} return ob, reward, done, {}
def _get_obs(self): def _get_obs(self):

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -14,6 +16,7 @@ DEFAULT_CAMERA_CONFIG = {
class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle): class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__( def __init__(
self, self,
render_mode: Optional[str] = None,
xml_file="walker2d.xml", xml_file="walker2d.xml",
forward_reward_weight=1.0, forward_reward_weight=1.0,
ctrl_cost_weight=1e-3, ctrl_cost_weight=1e-3,
@@ -23,6 +26,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
healthy_angle_range=(-1.0, 1.0), healthy_angle_range=(-1.0, 1.0),
reset_noise_scale=5e-3, reset_noise_scale=5e-3,
exclude_current_positions_from_observation=True, exclude_current_positions_from_observation=True,
**kwargs
): ):
utils.EzPickle.__init__(**locals()) utils.EzPickle.__init__(**locals())
@@ -41,7 +45,9 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, xml_file, 4, mujoco_bindings="mujoco_py") mujoco_env.MujocoEnv.__init__(
self, xml_file, 4, render_mode=render_mode, mujoco_bindings="mujoco_py"
)
@property @property
def healthy_reward(self): def healthy_reward(self):
@@ -88,8 +94,9 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
x_position_after = self.sim.data.qpos[0] x_position_after = self.sim.data.qpos[0]
x_velocity = (x_position_after - x_position_before) / self.dt x_velocity = (x_position_after - x_position_before) / self.dt
ctrl_cost = self.control_cost(action) self.renderer.render_step()
ctrl_cost = self.control_cost(action)
forward_reward = self._forward_reward_weight * x_velocity forward_reward = self._forward_reward_weight * x_velocity
healthy_reward = self.healthy_reward healthy_reward = self.healthy_reward

View File

@@ -1,3 +1,5 @@
from typing import Optional
import numpy as np import numpy as np
from gym import utils from gym import utils
@@ -158,6 +160,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__( def __init__(
self, self,
render_mode: Optional[str] = None,
xml_file="walker2d.xml", xml_file="walker2d.xml",
forward_reward_weight=1.0, forward_reward_weight=1.0,
ctrl_cost_weight=1e-3, ctrl_cost_weight=1e-3,
@@ -185,7 +188,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
exclude_current_positions_from_observation exclude_current_positions_from_observation
) )
mujoco_env.MujocoEnv.__init__(self, xml_file, 4) mujoco_env.MujocoEnv.__init__(self, xml_file, 4, render_mode=render_mode)
@property @property
def healthy_reward(self): def healthy_reward(self):
@@ -248,6 +251,7 @@ class Walker2dEnv(mujoco_env.MujocoEnv, utils.EzPickle):
"x_velocity": x_velocity, "x_velocity": x_velocity,
} }
self.renderer.render_step()
return observation, reward, done, info return observation, reward, done, info
def reset_model(self): def reset_model(self):

View File

@@ -6,6 +6,7 @@ import numpy as np
import gym import gym
from gym import spaces from gym import spaces
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer
def cmp(a, b): def cmp(a, b):
@@ -110,9 +111,12 @@ class BlackjackEnv(gym.Env):
* v0: Initial versions release (1.0.0) * v0: Initial versions release (1.0.0)
""" """
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4} metadata = {
"render_modes": ["human", "rgb_array", "single_rgb_array"],
"render_fps": 4,
}
def __init__(self, natural=False, sab=False): def __init__(self, render_mode: Optional[str] = None, natural=False, sab=False):
self.action_space = spaces.Discrete(2) self.action_space = spaces.Discrete(2)
self.observation_space = spaces.Tuple( self.observation_space = spaces.Tuple(
(spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)) (spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2))
@@ -125,6 +129,10 @@ class BlackjackEnv(gym.Env):
# Flag for full agreement with the (Sutton and Barto, 2018) definition. Overrides self.natural # Flag for full agreement with the (Sutton and Barto, 2018) definition. Overrides self.natural
self.sab = sab self.sab = sab
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
def step(self, action): def step(self, action):
assert self.action_space.contains(action) assert self.action_space.contains(action)
if action: # hit: add a card to players hand and return if action: # hit: add a card to players hand and return
@@ -151,6 +159,8 @@ class BlackjackEnv(gym.Env):
): ):
# Natural gives extra points, but doesn't autowin. Legacy implementation # Natural gives extra points, but doesn't autowin. Legacy implementation
reward = 1.5 reward = 1.5
self.renderer.render_step()
return self._get_obs(), reward, done, {} return self._get_obs(), reward, done, {}
def _get_obs(self): def _get_obs(self):
@@ -165,12 +175,24 @@ class BlackjackEnv(gym.Env):
super().reset(seed=seed) super().reset(seed=seed)
self.dealer = draw_hand(self.np_random) self.dealer = draw_hand(self.np_random)
self.player = draw_hand(self.np_random) self.player = draw_hand(self.np_random)
self.renderer.reset()
self.renderer.render_step()
if not return_info: if not return_info:
return self._get_obs() return self._get_obs()
else: else:
return self._get_obs(), {} return self._get_obs(), {}
def render(self, mode="human"): def render(self, mode="human"):
if self.render_mode is not None:
return self.renderer.get_renders()
else:
return self._render(mode)
def _render(self, mode):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
except ImportError: except ImportError:

View File

@@ -7,6 +7,7 @@ import numpy as np
from gym import Env, spaces from gym import Env, spaces
from gym.envs.toy_text.utils import categorical_sample from gym.envs.toy_text.utils import categorical_sample
from gym.utils.renderer import Renderer
UP = 0 UP = 0
RIGHT = 1 RIGHT = 1
@@ -62,7 +63,7 @@ class CliffWalkingEnv(Env):
metadata = {"render_modes": ["human", "ansi"], "render_fps": 4} metadata = {"render_modes": ["human", "ansi"], "render_fps": 4}
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
self.shape = (4, 12) self.shape = (4, 12)
self.start_state_index = np.ravel_multi_index((3, 0), self.shape) self.start_state_index = np.ravel_multi_index((3, 0), self.shape)
@@ -91,6 +92,10 @@ class CliffWalkingEnv(Env):
self.observation_space = spaces.Discrete(self.nS) self.observation_space = spaces.Discrete(self.nS)
self.action_space = spaces.Discrete(self.nA) self.action_space = spaces.Discrete(self.nA)
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
def _limit_coordinates(self, coord: np.ndarray) -> np.ndarray: def _limit_coordinates(self, coord: np.ndarray) -> np.ndarray:
"""Prevent the agent from falling out of the grid world.""" """Prevent the agent from falling out of the grid world."""
coord[0] = min(coord[0], self.shape[0] - 1) coord[0] = min(coord[0], self.shape[0] - 1)
@@ -125,6 +130,7 @@ class CliffWalkingEnv(Env):
p, s, r, d = transitions[i] p, s, r, d = transitions[i]
self.s = s self.s = s
self.lastaction = a self.lastaction = a
self.renderer.render_step()
return (int(s), r, d, {"prob": p}) return (int(s), r, d, {"prob": p})
def reset( def reset(
@@ -137,12 +143,21 @@ class CliffWalkingEnv(Env):
super().reset(seed=seed) super().reset(seed=seed)
self.s = categorical_sample(self.initial_state_distrib, self.np_random) self.s = categorical_sample(self.initial_state_distrib, self.np_random)
self.lastaction = None self.lastaction = None
self.renderer.reset()
self.renderer.render_step()
if not return_info: if not return_info:
return int(self.s) return int(self.s)
else: else:
return int(self.s), {"prob": 1} return int(self.s), {"prob": 1}
def render(self, mode="human"): def render(self, mode="human"):
if self.render_mode is not None:
return self.renderer.get_renders()
else:
return self._render(mode)
def _render(self, mode):
assert mode in self.metadata["render_modes"]
outfile = StringIO() if mode == "ansi" else sys.stdout outfile = StringIO() if mode == "ansi" else sys.stdout
for s in range(self.nS): for s in range(self.nS):

View File

@@ -8,6 +8,7 @@ import numpy as np
from gym import Env, spaces, utils from gym import Env, spaces, utils
from gym.envs.toy_text.utils import categorical_sample from gym.envs.toy_text.utils import categorical_sample
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer
LEFT = 0 LEFT = 0
DOWN = 1 DOWN = 1
@@ -144,9 +145,18 @@ class FrozenLakeEnv(Env):
* v0: Initial versions release (1.0.0) * v0: Initial versions release (1.0.0)
""" """
metadata = {"render_modes": ["human", "ansi", "rgb_array"], "render_fps": 4} metadata = {
"render_modes": ["human", "ansi", "rgb_array", "single_rgb_array"],
"render_fps": 4,
}
def __init__(self, desc=None, map_name="4x4", is_slippery=True): def __init__(
self,
render_mode: Optional[str] = None,
desc=None,
map_name="4x4",
is_slippery=True,
):
if desc is None and map_name is None: if desc is None and map_name is None:
desc = generate_random_map() desc = generate_random_map()
elif desc is None: elif desc is None:
@@ -205,6 +215,10 @@ class FrozenLakeEnv(Env):
self.observation_space = spaces.Discrete(nS) self.observation_space = spaces.Discrete(nS)
self.action_space = spaces.Discrete(nA) self.action_space = spaces.Discrete(nA)
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
# pygame utils # pygame utils
self.window_size = (min(64 * ncol, 512), min(64 * nrow, 512)) self.window_size = (min(64 * ncol, 512), min(64 * nrow, 512))
self.window_surface = None self.window_surface = None
@@ -222,6 +236,9 @@ class FrozenLakeEnv(Env):
p, s, r, d = transitions[i] p, s, r, d = transitions[i]
self.s = s self.s = s
self.lastaction = a self.lastaction = a
self.renderer.render_step()
return (int(s), r, d, {"prob": p}) return (int(s), r, d, {"prob": p})
def reset( def reset(
@@ -235,19 +252,28 @@ class FrozenLakeEnv(Env):
self.s = categorical_sample(self.initial_state_distrib, self.np_random) self.s = categorical_sample(self.initial_state_distrib, self.np_random)
self.lastaction = None self.lastaction = None
self.renderer.reset()
self.renderer.render_step()
if not return_info: if not return_info:
return int(self.s) return int(self.s)
else: else:
return int(self.s), {"prob": 1} return int(self.s), {"prob": 1}
def render(self, mode="human"): def render(self, mode="human"):
desc = self.desc.tolist() if self.render_mode is not None:
if mode == "ansi": return self.renderer.get_renders()
return self._render_text(desc)
else: else:
return self._render_gui(desc, mode) return self._render(mode)
def _render_gui(self, desc, mode): def _render(self, mode="human"):
assert mode in self.metadata["render_modes"]
if mode == "ansi":
return self._render_text()
elif mode in {"human", "rgb_array", "single_rgb_array"}:
return self._render_gui(mode)
def _render_gui(self, mode):
try: try:
import pygame import pygame
except ImportError: except ImportError:
@@ -261,7 +287,7 @@ class FrozenLakeEnv(Env):
pygame.display.set_caption("Frozen Lake") pygame.display.set_caption("Frozen Lake")
if mode == "human": if mode == "human":
self.window_surface = pygame.display.set_mode(self.window_size) self.window_surface = pygame.display.set_mode(self.window_size)
else: # rgb_array elif mode in {"rgb_array", "single_rgb_array"}:
self.window_surface = pygame.Surface(self.window_size) self.window_surface = pygame.Surface(self.window_size)
if self.clock is None: if self.clock is None:
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
@@ -315,6 +341,7 @@ class FrozenLakeEnv(Env):
goal_img = pygame.transform.scale(self.goal_img, (cell_width, cell_height)) goal_img = pygame.transform.scale(self.goal_img, (cell_width, cell_height))
start_img = pygame.transform.scale(self.start_img, (small_cell_w, small_cell_h)) start_img = pygame.transform.scale(self.start_img, (small_cell_w, small_cell_h))
desc = self.desc.tolist()
for y in range(self.nrow): for y in range(self.nrow):
for x in range(self.ncol): for x in range(self.ncol):
rect = (x * cell_width, y * cell_height, cell_width, cell_height) rect = (x * cell_width, y * cell_height, cell_width, cell_height)
@@ -351,7 +378,7 @@ class FrozenLakeEnv(Env):
pygame.event.pump() pygame.event.pump()
pygame.display.update() pygame.display.update()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
else: # rgb_array elif mode in {"rgb_array", "single_rgb_array"}:
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.window_surface)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.window_surface)), axes=(1, 0, 2)
) )
@@ -365,7 +392,8 @@ class FrozenLakeEnv(Env):
big_rect[1] + offset_h, big_rect[1] + offset_h,
) )
def _render_text(self, desc): def _render_text(self):
desc = self.desc.tolist()
outfile = StringIO() outfile = StringIO()
row, col = self.s // self.ncol, self.s % self.ncol row, col = self.s // self.ncol, self.s % self.ncol

View File

@@ -8,6 +8,7 @@ import numpy as np
from gym import Env, spaces, utils from gym import Env, spaces, utils
from gym.envs.toy_text.utils import categorical_sample from gym.envs.toy_text.utils import categorical_sample
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer
MAP = [ MAP = [
"+---------+", "+---------+",
@@ -104,9 +105,12 @@ class TaxiEnv(Env):
* v0: Initial versions release * v0: Initial versions release
""" """
metadata = {"render_modes": ["human", "ansi", "rgb_array"], "render_fps": 4} metadata = {
"render_modes": ["human", "ansi", "rgb_array", "single_rgb_array"],
"render_fps": 4,
}
def __init__(self): def __init__(self, render_mode: Optional[str] = None):
self.desc = np.asarray(MAP, dtype="c") self.desc = np.asarray(MAP, dtype="c")
self.locs = locs = [(0, 0), (0, 4), (4, 0), (4, 3)] self.locs = locs = [(0, 0), (0, 4), (4, 0), (4, 3)]
@@ -169,6 +173,10 @@ class TaxiEnv(Env):
self.action_space = spaces.Discrete(num_actions) self.action_space = spaces.Discrete(num_actions)
self.observation_space = spaces.Discrete(num_states) self.observation_space = spaces.Discrete(num_states)
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
# pygame utils # pygame utils
self.window = None self.window = None
self.clock = None self.clock = None
@@ -213,6 +221,7 @@ class TaxiEnv(Env):
p, s, r, d = transitions[i] p, s, r, d = transitions[i]
self.s = s self.s = s
self.lastaction = a self.lastaction = a
self.renderer.render_step()
return (int(s), r, d, {"prob": p}) return (int(s), r, d, {"prob": p})
def reset( def reset(
@@ -226,15 +235,24 @@ class TaxiEnv(Env):
self.s = categorical_sample(self.initial_state_distrib, self.np_random) self.s = categorical_sample(self.initial_state_distrib, self.np_random)
self.lastaction = None self.lastaction = None
self.taxi_orientation = 0 self.taxi_orientation = 0
self.renderer.reset()
self.renderer.render_step()
if not return_info: if not return_info:
return int(self.s) return int(self.s)
else: else:
return int(self.s), {"prob": 1} return int(self.s), {"prob": 1}
def render(self, mode="human"): def render(self, mode="human"):
if self.render_mode is not None:
return self.renderer.get_renders()
else:
return self._render(mode)
def _render(self, mode):
assert mode in self.metadata["render_modes"]
if mode == "ansi": if mode == "ansi":
return self._render_text() return self._render_text()
else: elif mode in {"human", "rgb_array", "single_rgb_array"}:
return self._render_gui(mode) return self._render_gui(mode)
def _render_gui(self, mode): def _render_gui(self, mode):
@@ -250,7 +268,7 @@ class TaxiEnv(Env):
pygame.display.set_caption("Taxi") pygame.display.set_caption("Taxi")
if mode == "human": if mode == "human":
self.window = pygame.display.set_mode(WINDOW_SIZE) self.window = pygame.display.set_mode(WINDOW_SIZE)
else: # "rgb_array" elif mode in {"rgb_array", "single_rgb_array"}:
self.window = pygame.Surface(WINDOW_SIZE) self.window = pygame.Surface(WINDOW_SIZE)
if self.clock is None: if self.clock is None:
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
@@ -358,7 +376,7 @@ class TaxiEnv(Env):
if mode == "human": if mode == "human":
pygame.display.update() pygame.display.update()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
else: # rgb_array elif mode in {"rgb_array", "single_rgb_array"}:
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.window)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.window)), axes=(1, 0, 2)
) )

View File

@@ -170,37 +170,39 @@ def check_reset_options(env: gym.Env):
) )
# Check render cannot be covered by CI def check_render(env: gym.Env, warn: bool = True):
def check_render(env: gym.Env, headless: bool = False): """Check the declared render modes/fps of the environment.
"""Check the declared render modes/fps and the :meth:`render`/:meth:`close` method of the environment.
Args: Args:
env: The environment to check env: The environment to check
headless: Whether to disable render modes that require a graphical interface. False by default. warn: Whether to output additional warnings
""" """
render_modes = env.metadata.get("render_modes") render_modes = env.metadata.get("render_modes")
if render_modes is None: if render_modes is None:
logger.warn( if warn:
"No render modes was declared in the environment (env.metadata['render_modes'] is None or not defined), you may have trouble when calling `.render()`" logger.warn(
) "No render modes was declared in the environment "
" (env.metadata['render_modes'] is None or not defined), "
"you may have trouble when calling `.render()`"
)
render_fps = env.metadata.get("render_fps") render_fps = env.metadata.get("render_fps")
# We only require `render_fps` if rendering is actually implemented # We only require `render_fps` if rendering is actually implemented
if render_fps is None: if render_fps is None and render_modes is not None and len(render_modes) > 0:
logger.warn( if warn:
"No render fps was declared in the environment (env.metadata['render_fps'] is None or not defined), rendering may occur at inconsistent fps" logger.warn(
) "No render fps was declared in the environment "
" (env.metadata['render_fps'] is None or not defined), "
"rendering may occur at inconsistent fps"
)
if render_modes is not None: if warn:
# Don't check render mode that require a if not hasattr(env, "render_mode"): # TODO: raise an error with gym 1.0
# graphical interface (useful for CI) logger.warn("Environments must define render_mode attribute.")
if headless and "human" in render_modes: elif env.render_mode is not None and env.render_mode not in render_modes:
render_modes.remove("human") logger.warn(
"The environment was initialized successfully with an unsupported render mode."
# Check all declared render modes )
for mode in render_modes:
env.render(mode=mode)
env.close()
def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = True): def check_env(env: gym.Env, warn: bool = None, skip_render_check: bool = True):

79
gym/utils/renderer.py Normal file
View File

@@ -0,0 +1,79 @@
"""A utility class to collect render frames from a function that computes a single frame."""
from typing import Any, Callable, List, Optional, Set

# Default set of modes for which the render function returns None (nothing to collect).
NO_RETURNS_RENDER = {"human"}
# Default set of modes for which render returns just a single frame of the current state.
SINGLE_RENDER = {"single_rgb_array", "single_depth_array", "single_state_pixels"}


class Renderer:
    """This class serves to easily integrate collection of renders for environments that can compute a single render.

    To use this class:
    - instantiate it with the mode and the function that computes a single frame
    - call :meth:`render_step` each time a frame should be saved in the list
      (usually at the end of the environment's ``step`` and ``reset`` methods)
    - call :meth:`get_renders` whenever you want to retrieve the collected renders
      (usually in the environment's ``render`` method)
    - call :meth:`reset` to clear the render list
      (usually in the environment's ``reset`` method)
    """

    def __init__(
        self,
        mode: Optional[str],
        render: Callable[[str], Any],
        no_returns_render: Optional[Set[str]] = None,
        single_render: Optional[Set[str]] = None,
    ):
        """Instantiates a Renderer object.

        Args:
            mode (Optional[str]): Way to render
            render (Callable[[str], Any]): Function that receives the mode and computes a single frame
            no_returns_render (Optional[Set[str]]): Set of render modes that don't return any value.
                The default value is the set {"human"}.
            single_render (Optional[Set[str]]): Set of render modes that should return a single frame.
                The default value is the set {"single_rgb_array", "single_depth_array", "single_state_pixels"}.
        """
        if no_returns_render is None:
            no_returns_render = NO_RETURNS_RENDER
        if single_render is None:
            single_render = SINGLE_RENDER

        self.no_returns_render = no_returns_render
        self.single_render = single_render
        self.mode = mode
        self.render = render
        # Accumulates one frame per render_step() call until get_renders() pops them.
        self.render_list: List = []

    def render_step(self) -> None:
        """Computes a frame and saves it to the render collection list.

        This method should usually be called inside the environment's step and reset methods.
        """
        # Consult the instance-level sets (customizable via __init__), not the
        # module-level defaults; otherwise custom sets passed to __init__ are ignored.
        if self.mode is not None and self.mode not in self.single_render:
            render_return = self.render(self.mode)
            if self.mode not in self.no_returns_render:
                self.render_list.append(render_return)

    def get_renders(self) -> Optional[List]:
        """Pops all the frames from the render collection list.

        This method should usually be called in the environment's render method to retrieve the frames collected up to this time step.

        Returns:
            A single frame for modes in ``single_render``, the accumulated list of
            frames for collecting modes, or ``None`` for modes with no return value
            (or when no mode is set).
        """
        if self.mode in self.single_render:
            # Single-frame modes bypass the collection list entirely.
            return self.render(self.mode)
        elif self.mode is not None and self.mode not in self.no_returns_render:
            renders = self.render_list
            self.render_list = []
            return renders
        return None

    def reset(self):
        """Resets the render collection list.

        This method should usually be called inside the environment's reset method.
        """
        self.render_list = []

View File

@@ -7,7 +7,7 @@ import shutil
import subprocess import subprocess
import tempfile import tempfile
from io import StringIO from io import StringIO
from typing import Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
@@ -19,11 +19,16 @@ def touch(path: str):
open(path, "a").close() open(path, "a").close()
class VideoRecorder: class VideoRecorder: # TODO: remove with gym 1.0
"""VideoRecorder renders a nice movie of a rollout, frame by frame. """VideoRecorder renders a nice movie of a rollout, frame by frame.
It comes with an ``enabled`` option, so you can still use the same code on episodes where you don't want to record video. It comes with an ``enabled`` option, so you can still use the same code on episodes where you don't want to record video.
Note:
VideoRecorder is deprecated.
Collect the frames with render_mode='rgb_array' and use an external library like MoviePy:
https://zulko.github.io/moviepy/getting_started/videoclips.html#videoclip
Note: Note:
You are responsible for calling :meth:`close` on a created VideoRecorder, or else you may leak an encoder process. You are responsible for calling :meth:`close` on a created VideoRecorder, or else you may leak an encoder process.
""" """
@@ -50,6 +55,11 @@ class VideoRecorder:
Error: Invalid path given that must have a particular file extension Error: Invalid path given that must have a particular file extension
""" """
modes = env.metadata.get("render_modes", []) modes = env.metadata.get("render_modes", [])
logger.deprecation(
"VideoRecorder is deprecated.\n"
"Collect the frames with render_mode='rgb_array' and use an external library like MoviePy: "
"https://zulko.github.io/moviepy/getting_started/videoclips.html#videoclip"
)
# backward-compatibility mode: # backward-compatibility mode:
backward_compatible_mode = env.metadata.get("render.modes", []) backward_compatible_mode = env.metadata.get("render.modes", [])
@@ -64,10 +74,6 @@ class VideoRecorder:
self.enabled = enabled self.enabled = enabled
self._closed = False self._closed = False
# Don't bother setting anything else if not enabled
if not self.enabled:
return
self.ansi_mode = False self.ansi_mode = False
if "rgb_array" not in modes: if "rgb_array" not in modes:
if "ansi" in modes: if "ansi" in modes:
@@ -78,7 +84,10 @@ class VideoRecorder:
) )
# Whoops, turns out we shouldn't be enabled after all # Whoops, turns out we shouldn't be enabled after all
self.enabled = False self.enabled = False
return
# Don't bother setting anything else if not enabled
if not self.enabled:
return
if path is not None and base_path is not None: if path is not None and base_path is not None:
raise error.Error("You can pass at most one of `path` or `base_path`.") raise error.Error("You can pass at most one of `path` or `base_path`.")
@@ -171,6 +180,8 @@ class VideoRecorder:
render_mode = "ansi" if self.ansi_mode else "rgb_array" render_mode = "ansi" if self.ansi_mode else "rgb_array"
frame = self.env.render(mode=render_mode) frame = self.env.render(mode=render_mode)
if isinstance(frame, List):
frame = frame[-1]
if frame is None: if frame is None:
if self._async: if self._async:

View File

@@ -2,7 +2,7 @@
import collections import collections
import copy import copy
from collections.abc import MutableMapping from collections.abc import MutableMapping
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import numpy as np import numpy as np
@@ -95,10 +95,6 @@ class PixelObservationWrapper(gym.ObservationWrapper):
for key in pixel_keys: for key in pixel_keys:
render_kwargs.setdefault(key, {}) render_kwargs.setdefault(key, {})
render_mode = render_kwargs[key].pop("mode", "rgb_array")
assert render_mode == "rgb_array", render_mode
render_kwargs[key]["mode"] = "rgb_array"
wrapped_observation_space = env.observation_space wrapped_observation_space = env.observation_space
if isinstance(wrapped_observation_space, spaces.Box): if isinstance(wrapped_observation_space, spaces.Box):
@@ -133,6 +129,7 @@ class PixelObservationWrapper(gym.ObservationWrapper):
pixels_spaces = {} pixels_spaces = {}
for pixel_key in pixel_keys: for pixel_key in pixel_keys:
pixels = self.env.render(**render_kwargs[pixel_key]) pixels = self.env.render(**render_kwargs[pixel_key])
pixels = pixels[-1] if isinstance(pixels, List) else pixels
if np.issubdtype(pixels.dtype, np.integer): if np.issubdtype(pixels.dtype, np.integer):
low, high = (0, 255) low, high = (0, 255)

View File

@@ -24,7 +24,7 @@ def capped_cubic_video_schedule(episode_id: int) -> bool:
return episode_id % 1000 == 0 return episode_id % 1000 == 0
class RecordVideo(gym.Wrapper): class RecordVideo(gym.Wrapper): # TODO: remove with gym 1.0
"""This wrapper records videos of rollouts. """This wrapper records videos of rollouts.
Usually, you only want to record episodes intermittently, say every hundredth episode. Usually, you only want to record episodes intermittently, say every hundredth episode.
@@ -35,6 +35,11 @@ class RecordVideo(gym.Wrapper):
By default, the recording will be stopped once a `done` signal has been emitted by the environment. However, you can By default, the recording will be stopped once a `done` signal has been emitted by the environment. However, you can
also create recordings of fixed length (possibly spanning several episodes) by passing a strictly positive value for also create recordings of fixed length (possibly spanning several episodes) by passing a strictly positive value for
``video_length``. ``video_length``.
Note:
RecordVideo is deprecated.
Collect the frames with render_mode='rgb_array' and use an external library like MoviePy:
https://zulko.github.io/moviepy/getting_started/videoclips.html#videoclip
""" """
def __init__( def __init__(
@@ -58,6 +63,11 @@ class RecordVideo(gym.Wrapper):
name_prefix (str): Will be prepended to the filename of the recordings name_prefix (str): Will be prepended to the filename of the recordings
""" """
super().__init__(env) super().__init__(env)
logger.deprecation(
"RecordVideo is deprecated.\n"
"Collect the frames with render_mode='rgb_array' and use an external library like MoviePy: "
"https://zulko.github.io/moviepy/getting_started/videoclips.html#videoclip"
)
if episode_trigger is None and step_trigger is None: if episode_trigger is None and step_trigger is None:
episode_trigger = capped_cubic_video_schedule episode_trigger = capped_cubic_video_schedule
@@ -90,7 +100,13 @@ class RecordVideo(gym.Wrapper):
def reset(self, **kwargs): def reset(self, **kwargs):
"""Reset the environment using kwargs and then starts recording if video enabled.""" """Reset the environment using kwargs and then starts recording if video enabled."""
observations = super().reset(**kwargs) observations = super().reset(**kwargs)
if not self.recording and self._video_enabled(): if self.recording:
self.video_recorder.capture_frame()
self.recorded_frames += 1
if self.video_length > 0:
if self.recorded_frames > self.video_length:
self.close_video_recorder()
elif self._video_enabled():
self.start_video_recorder() self.start_video_recorder()
return observations return observations

View File

@@ -1,3 +1,5 @@
from typing import List
import numpy as np import numpy as np
import pytest import pytest
@@ -49,15 +51,6 @@ def test_env(spec):
assert ( assert (
observation.dtype == ob_space.dtype observation.dtype == ob_space.dtype
), f"Step observation dtype: {ob.dtype}, expected: {ob_space.dtype}" ), f"Step observation dtype: {ob.dtype}, expected: {ob_space.dtype}"
for mode in env.metadata.get("render_modes", []):
if not (mode == "human" and spec.entry_point.startswith("gym.envs.mujoco")):
env.render(mode=mode)
# Make sure we can render the environment after close.
for mode in env.metadata.get("render_modes", []):
if not (mode == "human" and spec.entry_point.startswith("gym.envs.mujoco")):
env.render(mode=mode)
env.close() env.close()
@@ -79,14 +72,30 @@ def test_reset_info(spec):
env.close() env.close()
@pytest.mark.parametrize(
"spec", spec_list_no_mujoco_py, ids=[spec.id for spec in spec_list_no_mujoco_py]
)
def test_render_modes(spec):
env = spec.make()
for mode in env.metadata.get("render_modes", []):
if mode != "human":
new_env = spec.make(render_mode=mode)
new_env.reset()
new_env.step(new_env.action_space.sample())
new_env.render()
def test_env_render_result_is_immutable(): def test_env_render_result_is_immutable():
environs = [ environs = [
envs.make("Taxi-v3"), envs.make("Taxi-v3", render_mode="ansi"),
envs.make("FrozenLake-v1"), envs.make("FrozenLake-v1", render_mode="ansi"),
] ]
for env in environs: for env in environs:
env.reset() env.reset()
output = env.render(mode="ansi") output = env.render()
assert isinstance(output, str) assert isinstance(output, List)
assert isinstance(output[0], str)
env.close() env.close()

View File

@@ -12,6 +12,9 @@ class ActionDictTestEnv(gym.Env):
action_space = Dict({"position": Discrete(1), "velocity": Discrete(1)}) action_space = Dict({"position": Discrete(1), "velocity": Discrete(1)})
observation_space = Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32) observation_space = Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32)
def __init__(self, render_mode: Optional[str] = None):
self.render_mode = render_mode
def step(self, action): def step(self, action):
observation = np.array([1.0, 1.5, 0.5]) observation = np.array([1.0, 1.5, 0.5])
reward = 1 reward = 1
@@ -22,7 +25,7 @@ class ActionDictTestEnv(gym.Env):
super().reset(seed=seed) super().reset(seed=seed)
return np.array([1.0, 1.5, 0.5]) return np.array([1.0, 1.5, 0.5])
def render(self, mode="human"): def render(self, mode: Optional[str] = "human"):
pass pass

View File

@@ -105,11 +105,11 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
@pytest.mark.parametrize("shared_memory", [True, False]) @pytest.mark.parametrize("shared_memory", [True, False])
def test_call_async_vector_env(shared_memory): def test_call_async_vector_env(shared_memory):
env_fns = [make_env("CartPole-v1", i) for i in range(4)] env_fns = [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(4)]
try: try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
_ = env.reset() _ = env.reset()
images = env.call("render", mode="rgb_array") images = env.call("render")
gravity = env.call("gravity") gravity = env.call("gravity")
finally: finally:
env.close() env.close()
@@ -117,7 +117,8 @@ def test_call_async_vector_env(shared_memory):
assert isinstance(images, tuple) assert isinstance(images, tuple)
assert len(images) == 4 assert len(images) == 4
for i in range(4): for i in range(4):
assert isinstance(images[i], np.ndarray) assert len(images[i]) == 1
assert isinstance(images[i][0], np.ndarray)
assert isinstance(gravity, tuple) assert isinstance(gravity, tuple)
assert len(gravity) == 4 assert len(gravity) == 4

View File

@@ -105,11 +105,11 @@ def test_step_sync_vector_env(use_single_action_space):
def test_call_sync_vector_env(): def test_call_sync_vector_env():
env_fns = [make_env("CartPole-v1", i) for i in range(4)] env_fns = [make_env("CartPole-v1", i, render_mode="rgb_array") for i in range(4)]
try: try:
env = SyncVectorEnv(env_fns) env = SyncVectorEnv(env_fns)
_ = env.reset() _ = env.reset()
images = env.call("render", mode="rgb_array") images = env.call("render")
gravity = env.call("gravity") gravity = env.call("gravity")
finally: finally:
env.close() env.close()
@@ -117,7 +117,8 @@ def test_call_sync_vector_env():
assert isinstance(images, tuple) assert isinstance(images, tuple)
assert len(images) == 4 assert len(images) == 4
for i in range(4): for i in range(4):
assert isinstance(images[i], np.ndarray) assert len(images[i]) == 1
assert isinstance(images[i][0], np.ndarray)
assert isinstance(gravity, tuple) assert isinstance(gravity, tuple)
assert len(gravity) == 4 assert len(gravity) == 4

View File

@@ -107,9 +107,9 @@ class CustomSpaceEnv(gym.Env):
return observation, reward, done, {} return observation, reward, done, {}
def make_env(env_name, seed): def make_env(env_name, seed, **kwargs):
def _make(): def _make():
env = gym.make(env_name) env = gym.make(env_name, **kwargs)
env.action_space.seed(seed) env.action_space.seed(seed)
env.reset(seed=seed) env.reset(seed=seed)
return env return env

View File

@@ -9,7 +9,7 @@ from gym.wrappers.filter_observation import FilterObservation
class FakeEnvironment(gym.Env): class FakeEnvironment(gym.Env):
def __init__(self, observation_keys=("state")): def __init__(self, render_mode=None, observation_keys=("state")):
self.observation_space = spaces.Dict( self.observation_space = spaces.Dict(
{ {
name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32) name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
@@ -17,11 +17,10 @@ class FakeEnvironment(gym.Env):
} }
) )
self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32) self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32)
self.render_mode = render_mode
def render(self, width=32, height=32, *args, **kwargs): def render(self, mode="human"):
del args image_shape = (32, 32, 3)
del kwargs
image_shape = (height, width, 3)
return np.zeros(image_shape, dtype=np.uint8) return np.zeros(image_shape, dtype=np.uint8)
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):

View File

@@ -10,15 +10,14 @@ from gym.wrappers import FilterObservation, FlattenObservation
class FakeEnvironment(gym.Env): class FakeEnvironment(gym.Env):
def __init__(self, observation_space): def __init__(self, observation_space, render_mode=None):
self.observation_space = observation_space self.observation_space = observation_space
self.obs_keys = self.observation_space.spaces.keys() self.obs_keys = self.observation_space.spaces.keys()
self.action_space = Box(shape=(1,), low=-1, high=1, dtype=np.float32) self.action_space = Box(shape=(1,), low=-1, high=1, dtype=np.float32)
self.render_mode = render_mode
def render(self, width=32, height=32, *args, **kwargs): def render(self, mode="human"):
del args image_shape = (32, 32, 3)
del kwargs
image_shape = (height, width, 3)
return np.zeros(image_shape, dtype=np.uint8) return np.zeros(image_shape, dtype=np.uint8)
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):

View File

@@ -10,12 +10,11 @@ from gym.wrappers.pixel_observation import STATE_KEY, PixelObservationWrapper
class FakeEnvironment(gym.Env): class FakeEnvironment(gym.Env):
def __init__(self): def __init__(self, render_mode=None):
self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32) self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32)
self.render_mode = render_mode
def render(self, width=32, height=32, *args, **kwargs): def render(self, mode="human", width=32, height=32):
del args
del kwargs
image_shape = (height, width, 3) image_shape = (height, width, 3)
return np.zeros(image_shape, dtype=np.uint8) return np.zeros(image_shape, dtype=np.uint8)
@@ -49,7 +48,7 @@ class FakeDictObservationEnvironment(FakeEnvironment):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
class TestPixelObservationWrapper: class TestPixelObservationWrapper(gym.Wrapper):
@pytest.mark.parametrize("pixels_only", (True, False)) @pytest.mark.parametrize("pixels_only", (True, False))
def test_dict_observation(self, pixels_only): def test_dict_observation(self, pixels_only):
pixel_key = "rgb" pixel_key = "rgb"

View File

@@ -42,7 +42,9 @@ def test_record_episode_statistics_reset_info():
("num_envs", "asynchronous"), [(1, False), (1, True), (4, False), (4, True)] ("num_envs", "asynchronous"), [(1, False), (1, True), (4, False), (4, True)]
) )
def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous): def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous):
envs = gym.vector.make("CartPole-v1", num_envs=num_envs, asynchronous=asynchronous) envs = gym.vector.make(
"CartPole-v1", render_mode=None, num_envs=num_envs, asynchronous=asynchronous
)
envs = RecordEpisodeStatistics(envs) envs = RecordEpisodeStatistics(envs)
max_episode_step = ( max_episode_step = (
envs.env_fns[0]().spec.max_episode_steps envs.env_fns[0]().spec.max_episode_steps

View File

@@ -7,7 +7,7 @@ from gym.wrappers import capped_cubic_video_schedule
def test_record_video_using_default_trigger(): def test_record_video_using_default_trigger():
env = gym.make("CartPole-v1") env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, "videos") env = gym.wrappers.RecordVideo(env, "videos")
env.reset() env.reset()
for _ in range(199): for _ in range(199):
@@ -25,7 +25,7 @@ def test_record_video_using_default_trigger():
def test_record_video_reset_return_info(): def test_record_video_reset_return_info():
env = gym.make("CartPole-v1") env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
ob_space = env.observation_space ob_space = env.observation_space
obs, info = env.reset(return_info=True) obs, info = env.reset(return_info=True)
@@ -35,7 +35,7 @@ def test_record_video_reset_return_info():
assert ob_space.contains(obs) assert ob_space.contains(obs)
assert isinstance(info, dict) assert isinstance(info, dict)
env = gym.make("CartPole-v1") env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
ob_space = env.observation_space ob_space = env.observation_space
obs = env.reset(return_info=False) obs = env.reset(return_info=False)
@@ -44,7 +44,7 @@ def test_record_video_reset_return_info():
shutil.rmtree("videos") shutil.rmtree("videos")
assert ob_space.contains(obs) assert ob_space.contains(obs)
env = gym.make("CartPole-v1") env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
ob_space = env.observation_space ob_space = env.observation_space
obs = env.reset() obs = env.reset()
@@ -55,7 +55,7 @@ def test_record_video_reset_return_info():
def test_record_video_step_trigger(): def test_record_video_step_trigger():
env = gym.make("CartPole-v1") env = gym.make("CartPole-v1", render_mode="rgb_array")
env._max_episode_steps = 20 env._max_episode_steps = 20
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0) env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
env.reset() env.reset()
@@ -71,9 +71,9 @@ def test_record_video_step_trigger():
shutil.rmtree("videos") shutil.rmtree("videos")
def make_env(gym_id, seed): def make_env(gym_id, seed, **kwargs):
def thunk(): def thunk():
env = gym.make(gym_id) env = gym.make(gym_id, **kwargs)
env._max_episode_steps = 20 env._max_episode_steps = 20
if seed == 1: if seed == 1:
env = gym.wrappers.RecordVideo( env = gym.wrappers.RecordVideo(
@@ -85,7 +85,9 @@ def make_env(gym_id, seed):
def test_record_video_within_vector(): def test_record_video_within_vector():
envs = gym.vector.SyncVectorEnv([make_env("CartPole-v1", 1 + i) for i in range(2)]) envs = gym.vector.SyncVectorEnv(
[make_env("CartPole-v1", 1 + i, render_mode="rgb_array") for i in range(2)]
)
envs = gym.wrappers.RecordEpisodeStatistics(envs) envs = gym.wrappers.RecordEpisodeStatistics(envs)
envs.reset() envs.reset()
for i in range(199): for i in range(199):

View File

@@ -9,21 +9,27 @@ from gym.wrappers.monitoring.video_recorder import VideoRecorder
class BrokenRecordableEnv: class BrokenRecordableEnv:
metadata = {"render_modes": [None, "rgb_array"]} metadata = {"render_modes": ["rgb_array"]}
def render(self, mode=None): def __init__(self, render_mode="rgb_array"):
self.render_mode = render_mode
def render(self, mode="human"):
pass pass
class UnrecordableEnv: class UnrecordableEnv:
metadata = {"render_modes": [None]} metadata = {"render_modes": [None]}
def render(self, mode=None): def __init__(self, render_mode=None):
self.render_mode = render_mode
def render(self, mode="human"):
pass pass
def test_record_simple(): def test_record_simple():
env = gym.make("CartPole-v1") env = gym.make("CartPole-v1", render_mode="rgb_array")
rec = VideoRecorder(env) rec = VideoRecorder(env)
env.reset() env.reset()
rec.capture_frame() rec.capture_frame()
@@ -43,7 +49,7 @@ def test_record_simple():
def test_autoclose(): def test_autoclose():
def record(): def record():
env = gym.make("CartPole-v1") env = gym.make("CartPole-v1", render_mode="rgb_array")
rec = VideoRecorder(env) rec = VideoRecorder(env)
env.reset() env.reset()
rec.capture_frame() rec.capture_frame()
@@ -96,7 +102,7 @@ def test_record_breaking_render_method():
def test_text_envs(): def test_text_envs():
env = gym.make("FrozenLake-v1") env = gym.make("FrozenLake-v1", render_mode="rgb_array")
video = VideoRecorder(env) video = VideoRecorder(env)
try: try:
env.reset() env.reset()