environment agnostic {mode}_list render mode (#3060)

This commit is contained in:
Omar Younis
2022-09-04 15:42:10 +02:00
committed by GitHub
parent f39747d6a2
commit 0608263025
45 changed files with 161 additions and 417 deletions

View File

@@ -9,7 +9,6 @@ import gym
from gym import error, spaces from gym import error, spaces
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils import EzPickle from gym.utils import EzPickle
from gym.utils.renderer import Renderer
try: try:
import Box2D import Box2D
@@ -164,7 +163,7 @@ class BipedalWalker(gym.Env, EzPickle):
""" """
metadata = { metadata = {
"render_modes": ["human", "rgb_array", "rgb_array_list"], "render_modes": ["human", "rgb_array"],
"render_fps": FPS, "render_fps": FPS,
} }
@@ -258,7 +257,6 @@ class BipedalWalker(gym.Env, EzPickle):
# state += [l.fraction for l in self.lidar] # state += [l.fraction for l in self.lidar]
self.render_mode = render_mode self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
self.screen: Optional[pygame.Surface] = None self.screen: Optional[pygame.Surface] = None
self.clock = None self.clock = None
@@ -512,7 +510,6 @@ class BipedalWalker(gym.Env, EzPickle):
return fraction return fraction
self.lidar = [LidarCallback() for _ in range(10)] self.lidar = [LidarCallback() for _ in range(10)]
self.renderer.reset()
return self.step(np.array([0, 0, 0, 0]))[0], {} return self.step(np.array([0, 0, 0, 0]))[0], {}
def step(self, action: np.ndarray): def step(self, action: np.ndarray):
@@ -601,14 +598,9 @@ class BipedalWalker(gym.Env, EzPickle):
terminated = True terminated = True
if pos[0] > (TERRAIN_LENGTH - TERRAIN_GRASS) * TERRAIN_STEP: if pos[0] > (TERRAIN_LENGTH - TERRAIN_GRASS) * TERRAIN_STEP:
terminated = True terminated = True
self.renderer.render_step()
return np.array(state, dtype=np.float32), reward, terminated, False, {} return np.array(state, dtype=np.float32), reward, terminated, False, {}
def render(self): def render(self):
return self.renderer.get_renders()
def _render(self, mode: str = "human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
from pygame import gfxdraw from pygame import gfxdraw
@@ -617,7 +609,7 @@ class BipedalWalker(gym.Env, EzPickle):
"pygame is not installed, run `pip install gym[box2d]`" "pygame is not installed, run `pip install gym[box2d]`"
) )
if self.screen is None and mode == "human": if self.screen is None and self.render_mode == "human":
pygame.init() pygame.init()
pygame.display.init() pygame.display.init()
self.screen = pygame.display.set_mode((VIEWPORT_W, VIEWPORT_H)) self.screen = pygame.display.set_mode((VIEWPORT_W, VIEWPORT_H))
@@ -736,13 +728,13 @@ class BipedalWalker(gym.Env, EzPickle):
self.surf = pygame.transform.flip(self.surf, False, True) self.surf = pygame.transform.flip(self.surf, False, True)
if mode == "human": if self.render_mode == "human":
assert self.screen is not None assert self.screen is not None
self.screen.blit(self.surf, (-self.scroll * SCALE, 0)) self.screen.blit(self.surf, (-self.scroll * SCALE, 0))
pygame.event.pump() pygame.event.pump()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
pygame.display.flip() pygame.display.flip()
elif mode in {"rgb_array", "rgb_array_list"}: elif self.render_mode == "rgb_array":
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.surf)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.surf)), axes=(1, 0, 2)
)[:, -VIEWPORT_W:] )[:, -VIEWPORT_W:]

View File

@@ -10,7 +10,6 @@ from gym import spaces
from gym.envs.box2d.car_dynamics import Car from gym.envs.box2d.car_dynamics import Car
from gym.error import DependencyNotInstalled, InvalidAction from gym.error import DependencyNotInstalled, InvalidAction
from gym.utils import EzPickle from gym.utils import EzPickle
from gym.utils.renderer import Renderer
try: try:
import Box2D import Box2D
@@ -184,8 +183,6 @@ class CarRacing(gym.Env, EzPickle):
metadata = { metadata = {
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array_list",
"state_pixels_list",
"rgb_array", "rgb_array",
"state_pixels", "state_pixels",
], ],
@@ -247,7 +244,6 @@ class CarRacing(gym.Env, EzPickle):
) )
self.render_mode = render_mode self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
def _destroy(self): def _destroy(self):
if not self.road: if not self.road:
@@ -517,7 +513,6 @@ class CarRacing(gym.Env, EzPickle):
) )
self.car = Car(self.world, *self.track[0][1:4]) self.car = Car(self.world, *self.track[0][1:4])
self.renderer.reset()
return self.step(None)[0], {} return self.step(None)[0], {}
def step(self, action: Union[np.ndarray, int]): def step(self, action: Union[np.ndarray, int]):
@@ -563,13 +558,12 @@ class CarRacing(gym.Env, EzPickle):
terminated = True terminated = True
step_reward = -100 step_reward = -100
self.renderer.render_step()
return self.state, step_reward, terminated, truncated, {} return self.state, step_reward, terminated, truncated, {}
def render(self): def render(self):
return self.renderer.get_renders() return self._render(self.render_mode)
def _render(self, mode: str = "human"): def _render(self, mode: str):
assert mode in self.metadata["render_modes"] assert mode in self.metadata["render_modes"]
pygame.font.init() pygame.font.init()
@@ -623,9 +617,9 @@ class CarRacing(gym.Env, EzPickle):
self.screen.blit(self.surf, (0, 0)) self.screen.blit(self.surf, (0, 0))
pygame.display.flip() pygame.display.flip()
if mode in {"rgb_array", "rgb_array_list"}: if mode == "rgb_array":
return self._create_image_array(self.surf, (VIDEO_W, VIDEO_H)) return self._create_image_array(self.surf, (VIDEO_W, VIDEO_H))
elif mode in {"state_pixels_list", "state_pixels"}: elif mode == "state_pixels":
return self._create_image_array(self.surf, (STATE_W, STATE_H)) return self._create_image_array(self.surf, (STATE_W, STATE_H))
else: else:
return self.isopen return self.isopen

View File

@@ -10,7 +10,6 @@ import gym
from gym import error, spaces from gym import error, spaces
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils import EzPickle, colorize from gym.utils import EzPickle, colorize
from gym.utils.renderer import Renderer
from gym.utils.step_api_compatibility import step_api_compatibility from gym.utils.step_api_compatibility import step_api_compatibility
try: try:
@@ -179,7 +178,7 @@ class LunarLander(gym.Env, EzPickle):
""" """
metadata = { metadata = {
"render_modes": ["human", "rgb_array", "rgb_array_list"], "render_modes": ["human", "rgb_array"],
"render_fps": FPS, "render_fps": FPS,
} }
@@ -287,7 +286,6 @@ class LunarLander(gym.Env, EzPickle):
self.action_space = spaces.Discrete(4) self.action_space = spaces.Discrete(4)
self.render_mode = render_mode self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
def _destroy(self): def _destroy(self):
if not self.moon: if not self.moon:
@@ -411,7 +409,6 @@ class LunarLander(gym.Env, EzPickle):
self.drawlist = [self.lander] + self.legs self.drawlist = [self.lander] + self.legs
self.renderer.reset()
return self.step(np.array([0, 0]) if self.continuous else 0)[0], {} return self.step(np.array([0, 0]) if self.continuous else 0)[0], {}
def _create_particle(self, mass, x, y, ttl): def _create_particle(self, mass, x, y, ttl):
@@ -589,14 +586,9 @@ class LunarLander(gym.Env, EzPickle):
if not self.lander.awake: if not self.lander.awake:
terminated = True terminated = True
reward = +100 reward = +100
self.renderer.render_step()
return np.array(state, dtype=np.float32), reward, terminated, False, {} return np.array(state, dtype=np.float32), reward, terminated, False, {}
def render(self): def render(self):
return self.renderer.get_renders()
def _render(self, mode="human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
from pygame import gfxdraw from pygame import gfxdraw
@@ -605,7 +597,7 @@ class LunarLander(gym.Env, EzPickle):
"pygame is not installed, run `pip install gym[box2d]`" "pygame is not installed, run `pip install gym[box2d]`"
) )
if self.screen is None and mode == "human": if self.screen is None and self.render_mode == "human":
pygame.init() pygame.init()
pygame.display.init() pygame.display.init()
self.screen = pygame.display.set_mode((VIEWPORT_W, VIEWPORT_H)) self.screen = pygame.display.set_mode((VIEWPORT_W, VIEWPORT_H))
@@ -692,13 +684,13 @@ class LunarLander(gym.Env, EzPickle):
self.surf = pygame.transform.flip(self.surf, False, True) self.surf = pygame.transform.flip(self.surf, False, True)
if mode == "human": if self.render_mode == "human":
assert self.screen is not None assert self.screen is not None
self.screen.blit(self.surf, (0, 0)) self.screen.blit(self.surf, (0, 0))
pygame.event.pump() pygame.event.pump()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
pygame.display.flip() pygame.display.flip()
elif mode in {"rgb_array", "rgb_array_list"}: elif self.render_mode == "rgb_array":
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.surf)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.surf)), axes=(1, 0, 2)
) )

View File

@@ -21,7 +21,6 @@ __author__ = "Christoph Dann <cdann@cdann.de>"
# SOURCE: # SOURCE:
# https://github.com/rlpy/rlpy/blob/master/rlpy/Domains/Acrobot.py # https://github.com/rlpy/rlpy/blob/master/rlpy/Domains/Acrobot.py
from gym.envs.classic_control import utils from gym.envs.classic_control import utils
from gym.utils.renderer import Renderer
class AcrobotEnv(core.Env): class AcrobotEnv(core.Env):
@@ -137,7 +136,7 @@ class AcrobotEnv(core.Env):
""" """
metadata = { metadata = {
"render_modes": ["human", "rgb_array", "rgb_array_list"], "render_modes": ["human", "rgb_array"],
"render_fps": 15, "render_fps": 15,
} }
@@ -168,7 +167,6 @@ class AcrobotEnv(core.Env):
def __init__(self, render_mode: Optional[str] = None): def __init__(self, render_mode: Optional[str] = None):
self.render_mode = render_mode self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
self.screen = None self.screen = None
self.clock = None self.clock = None
self.isopen = True self.isopen = True
@@ -191,8 +189,6 @@ class AcrobotEnv(core.Env):
np.float32 np.float32
) )
self.renderer.reset()
self.renderer.render_step()
return self._get_ob(), {} return self._get_ob(), {}
def step(self, a): def step(self, a):
@@ -220,7 +216,6 @@ class AcrobotEnv(core.Env):
terminated = self._terminal() terminated = self._terminal()
reward = -1.0 if not terminated else 0.0 reward = -1.0 if not terminated else 0.0
self.renderer.render_step()
return (self._get_ob(), reward, terminated, False, {}) return (self._get_ob(), reward, terminated, False, {})
def _get_ob(self): def _get_ob(self):
@@ -278,10 +273,6 @@ class AcrobotEnv(core.Env):
return dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0 return dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0
def render(self): def render(self):
return self.renderer.get_renders()
def _render(self, mode="human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
from pygame import gfxdraw from pygame import gfxdraw
@@ -292,12 +283,12 @@ class AcrobotEnv(core.Env):
if self.screen is None: if self.screen is None:
pygame.init() pygame.init()
if mode == "human": if self.render_mode == "human":
pygame.display.init() pygame.display.init()
self.screen = pygame.display.set_mode( self.screen = pygame.display.set_mode(
(self.SCREEN_DIM, self.SCREEN_DIM) (self.SCREEN_DIM, self.SCREEN_DIM)
) )
else: # mode in {"rgb_array", "rgb_array_list"} else: # mode in "rgb_array"
self.screen = pygame.Surface((self.SCREEN_DIM, self.SCREEN_DIM)) self.screen = pygame.Surface((self.SCREEN_DIM, self.SCREEN_DIM))
if self.clock is None: if self.clock is None:
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
@@ -353,12 +344,12 @@ class AcrobotEnv(core.Env):
surf = pygame.transform.flip(surf, False, True) surf = pygame.transform.flip(surf, False, True)
self.screen.blit(surf, (0, 0)) self.screen.blit(surf, (0, 0))
if mode == "human": if self.render_mode == "human":
pygame.event.pump() pygame.event.pump()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
pygame.display.flip() pygame.display.flip()
elif mode in {"rgb_array", "rgb_array_list"}: elif self.render_mode == "rgb_array":
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
) )

View File

@@ -12,7 +12,6 @@ import gym
from gym import logger, spaces from gym import logger, spaces
from gym.envs.classic_control import utils from gym.envs.classic_control import utils
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer
class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]): class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
@@ -83,7 +82,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
""" """
metadata = { metadata = {
"render_modes": ["human", "rgb_array", "rgb_array_list"], "render_modes": ["human", "rgb_array"],
"render_fps": 50, "render_fps": 50,
} }
@@ -118,7 +117,6 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
self.observation_space = spaces.Box(-high, high, dtype=np.float32) self.observation_space = spaces.Box(-high, high, dtype=np.float32)
self.render_mode = render_mode self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
self.screen_width = 600 self.screen_width = 600
self.screen_height = 400 self.screen_height = 400
@@ -185,7 +183,6 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
self.steps_beyond_terminated += 1 self.steps_beyond_terminated += 1
reward = 0.0 reward = 0.0
self.renderer.render_step()
return np.array(self.state, dtype=np.float32), reward, terminated, False, {} return np.array(self.state, dtype=np.float32), reward, terminated, False, {}
def reset( def reset(
@@ -202,15 +199,9 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
) # default high ) # default high
self.state = self.np_random.uniform(low=low, high=high, size=(4,)) self.state = self.np_random.uniform(low=low, high=high, size=(4,))
self.steps_beyond_terminated = None self.steps_beyond_terminated = None
self.renderer.reset()
self.renderer.render_step()
return np.array(self.state, dtype=np.float32), {} return np.array(self.state, dtype=np.float32), {}
def render(self): def render(self):
return self.renderer.get_renders()
def _render(self, mode="human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
from pygame import gfxdraw from pygame import gfxdraw
@@ -221,12 +212,12 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
if self.screen is None: if self.screen is None:
pygame.init() pygame.init()
if mode == "human": if self.render_mode == "human":
pygame.display.init() pygame.display.init()
self.screen = pygame.display.set_mode( self.screen = pygame.display.set_mode(
(self.screen_width, self.screen_height) (self.screen_width, self.screen_height)
) )
else: # mode in {"rgb_array", "rgb_array_list"} else: # mode == "rgb_array"
self.screen = pygame.Surface((self.screen_width, self.screen_height)) self.screen = pygame.Surface((self.screen_width, self.screen_height))
if self.clock is None: if self.clock is None:
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
@@ -289,12 +280,12 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
self.surf = pygame.transform.flip(self.surf, False, True) self.surf = pygame.transform.flip(self.surf, False, True)
self.screen.blit(self.surf, (0, 0)) self.screen.blit(self.surf, (0, 0))
if mode == "human": if self.render_mode == "human":
pygame.event.pump() pygame.event.pump()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
pygame.display.flip() pygame.display.flip()
elif mode in {"rgb_array", "rgb_array_list"}: elif self.render_mode == "rgb_array":
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
) )

View File

@@ -22,7 +22,6 @@ import gym
from gym import spaces from gym import spaces
from gym.envs.classic_control import utils from gym.envs.classic_control import utils
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer
class Continuous_MountainCarEnv(gym.Env): class Continuous_MountainCarEnv(gym.Env):
@@ -102,7 +101,7 @@ class Continuous_MountainCarEnv(gym.Env):
""" """
metadata = { metadata = {
"render_modes": ["human", "rgb_array", "rgb_array_list"], "render_modes": ["human", "rgb_array"],
"render_fps": 30, "render_fps": 30,
} }
@@ -126,7 +125,6 @@ class Continuous_MountainCarEnv(gym.Env):
) )
self.render_mode = render_mode self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
self.screen_width = 600 self.screen_width = 600
self.screen_height = 400 self.screen_height = 400
@@ -171,7 +169,6 @@ class Continuous_MountainCarEnv(gym.Env):
reward -= math.pow(action[0], 2) * 0.1 reward -= math.pow(action[0], 2) * 0.1
self.state = np.array([position, velocity], dtype=np.float32) self.state = np.array([position, velocity], dtype=np.float32)
self.renderer.render_step()
return self.state, reward, terminated, False, {} return self.state, reward, terminated, False, {}
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
@@ -180,19 +177,12 @@ class Continuous_MountainCarEnv(gym.Env):
# state/observations. # state/observations.
low, high = utils.maybe_parse_reset_bounds(options, -0.6, -0.4) low, high = utils.maybe_parse_reset_bounds(options, -0.6, -0.4)
self.state = np.array([self.np_random.uniform(low=low, high=high), 0]) self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
self.renderer.reset()
self.renderer.render_step()
return np.array(self.state, dtype=np.float32), {} return np.array(self.state, dtype=np.float32), {}
def _height(self, xs): def _height(self, xs):
return np.sin(3 * xs) * 0.45 + 0.55 return np.sin(3 * xs) * 0.45 + 0.55
def render(self): def render(self):
return self.renderer.get_renders()
def _render(self, mode="human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
from pygame import gfxdraw from pygame import gfxdraw
@@ -203,12 +193,12 @@ class Continuous_MountainCarEnv(gym.Env):
if self.screen is None: if self.screen is None:
pygame.init() pygame.init()
if mode == "human": if self.render_mode == "human":
pygame.display.init() pygame.display.init()
self.screen = pygame.display.set_mode( self.screen = pygame.display.set_mode(
(self.screen_width, self.screen_height) (self.screen_width, self.screen_height)
) )
else: # mode in {"rgb_array", "rgb_array_list"} else: # mode == "rgb_array":
self.screen = pygame.Surface((self.screen_width, self.screen_height)) self.screen = pygame.Surface((self.screen_width, self.screen_height))
if self.clock is None: if self.clock is None:
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
@@ -277,12 +267,12 @@ class Continuous_MountainCarEnv(gym.Env):
self.surf = pygame.transform.flip(self.surf, False, True) self.surf = pygame.transform.flip(self.surf, False, True)
self.screen.blit(self.surf, (0, 0)) self.screen.blit(self.surf, (0, 0))
if mode == "human": if self.render_mode == "human":
pygame.event.pump() pygame.event.pump()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
pygame.display.flip() pygame.display.flip()
elif mode in {"rgb_array", "rgb_array_list"}: elif self.render_mode == "rgb_array":
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
) )

View File

@@ -11,7 +11,6 @@ import gym
from gym import spaces from gym import spaces
from gym.envs.classic_control import utils from gym.envs.classic_control import utils
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer
class MountainCarEnv(gym.Env): class MountainCarEnv(gym.Env):
@@ -97,7 +96,7 @@ class MountainCarEnv(gym.Env):
""" """
metadata = { metadata = {
"render_modes": ["human", "rgb_array", "rgb_array_list"], "render_modes": ["human", "rgb_array"],
"render_fps": 30, "render_fps": 30,
} }
@@ -115,7 +114,6 @@ class MountainCarEnv(gym.Env):
self.high = np.array([self.max_position, self.max_speed], dtype=np.float32) self.high = np.array([self.max_position, self.max_speed], dtype=np.float32)
self.render_mode = render_mode self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
self.screen_width = 600 self.screen_width = 600
self.screen_height = 400 self.screen_height = 400
@@ -145,7 +143,6 @@ class MountainCarEnv(gym.Env):
reward = -1.0 reward = -1.0
self.state = (position, velocity) self.state = (position, velocity)
self.renderer.render_step()
return np.array(self.state, dtype=np.float32), reward, terminated, False, {} return np.array(self.state, dtype=np.float32), reward, terminated, False, {}
def reset( def reset(
@@ -159,18 +156,12 @@ class MountainCarEnv(gym.Env):
# state/observations. # state/observations.
low, high = utils.maybe_parse_reset_bounds(options, -0.6, -0.4) low, high = utils.maybe_parse_reset_bounds(options, -0.6, -0.4)
self.state = np.array([self.np_random.uniform(low=low, high=high), 0]) self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
self.renderer.reset()
self.renderer.render_step()
return np.array(self.state, dtype=np.float32), {} return np.array(self.state, dtype=np.float32), {}
def _height(self, xs): def _height(self, xs):
return np.sin(3 * xs) * 0.45 + 0.55 return np.sin(3 * xs) * 0.45 + 0.55
def render(self): def render(self):
return self.renderer.get_renders()
def _render(self, mode="human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
from pygame import gfxdraw from pygame import gfxdraw
@@ -181,12 +172,12 @@ class MountainCarEnv(gym.Env):
if self.screen is None: if self.screen is None:
pygame.init() pygame.init()
if mode == "human": if self.render_mode == "human":
pygame.display.init() pygame.display.init()
self.screen = pygame.display.set_mode( self.screen = pygame.display.set_mode(
(self.screen_width, self.screen_height) (self.screen_width, self.screen_height)
) )
else: # mode in {"rgb_array", "rgb_array_list"} else: # mode in "rgb_array"
self.screen = pygame.Surface((self.screen_width, self.screen_height)) self.screen = pygame.Surface((self.screen_width, self.screen_height))
if self.clock is None: if self.clock is None:
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
@@ -255,12 +246,12 @@ class MountainCarEnv(gym.Env):
self.surf = pygame.transform.flip(self.surf, False, True) self.surf = pygame.transform.flip(self.surf, False, True)
self.screen.blit(self.surf, (0, 0)) self.screen.blit(self.surf, (0, 0))
if mode == "human": if self.render_mode == "human":
pygame.event.pump() pygame.event.pump()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
pygame.display.flip() pygame.display.flip()
elif mode in {"rgb_array", "rgb_array_list"}: elif self.render_mode == "rgb_array":
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
) )

View File

@@ -9,7 +9,6 @@ import gym
from gym import spaces from gym import spaces
from gym.envs.classic_control import utils from gym.envs.classic_control import utils
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer
DEFAULT_X = np.pi DEFAULT_X = np.pi
DEFAULT_Y = 1.0 DEFAULT_Y = 1.0
@@ -89,7 +88,7 @@ class PendulumEnv(gym.Env):
""" """
metadata = { metadata = {
"render_modes": ["human", "rgb_array", "rgb_array_list"], "render_modes": ["human", "rgb_array"],
"render_fps": 30, "render_fps": 30,
} }
@@ -102,7 +101,6 @@ class PendulumEnv(gym.Env):
self.l = 1.0 self.l = 1.0
self.render_mode = render_mode self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
self.screen_dim = 500 self.screen_dim = 500
self.screen = None self.screen = None
@@ -135,7 +133,6 @@ class PendulumEnv(gym.Env):
newth = th + newthdot * dt newth = th + newthdot * dt
self.state = np.array([newth, newthdot]) self.state = np.array([newth, newthdot])
self.renderer.render_step()
return self._get_obs(), -costs, False, False, {} return self._get_obs(), -costs, False, False, {}
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
@@ -154,8 +151,6 @@ class PendulumEnv(gym.Env):
self.state = self.np_random.uniform(low=low, high=high) self.state = self.np_random.uniform(low=low, high=high)
self.last_u = None self.last_u = None
self.renderer.reset()
self.renderer.render_step()
return self._get_obs(), {} return self._get_obs(), {}
def _get_obs(self): def _get_obs(self):
@@ -163,10 +158,6 @@ class PendulumEnv(gym.Env):
return np.array([np.cos(theta), np.sin(theta), thetadot], dtype=np.float32) return np.array([np.cos(theta), np.sin(theta), thetadot], dtype=np.float32)
def render(self): def render(self):
return self.renderer.get_renders()
def _render(self, mode="human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
from pygame import gfxdraw from pygame import gfxdraw
@@ -177,12 +168,12 @@ class PendulumEnv(gym.Env):
if self.screen is None: if self.screen is None:
pygame.init() pygame.init()
if mode == "human": if self.render_mode == "human":
pygame.display.init() pygame.display.init()
self.screen = pygame.display.set_mode( self.screen = pygame.display.set_mode(
(self.screen_dim, self.screen_dim) (self.screen_dim, self.screen_dim)
) )
else: # mode in {"rgb_array", "rgb_array_list"} else: # mode in "rgb_array"
self.screen = pygame.Surface((self.screen_dim, self.screen_dim)) self.screen = pygame.Surface((self.screen_dim, self.screen_dim))
if self.clock is None: if self.clock is None:
self.clock = pygame.time.Clock() self.clock = pygame.time.Clock()
@@ -244,12 +235,12 @@ class PendulumEnv(gym.Env):
self.surf = pygame.transform.flip(self.surf, False, True) self.surf = pygame.transform.flip(self.surf, False, True)
self.screen.blit(self.surf, (0, 0)) self.screen.blit(self.surf, (0, 0))
if mode == "human": if self.render_mode == "human":
pygame.event.pump() pygame.event.pump()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
pygame.display.flip() pygame.display.flip()
else: # mode == "rgb_array_list": else: # mode == "rgb_array":
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
) )

View File

@@ -10,9 +10,7 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 20, "render_fps": 20,
} }
@@ -31,8 +29,6 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
xposafter = self.get_body_com("torso")[0] xposafter = self.get_body_com("torso")[0]
self.renderer.render_step()
forward_reward = (xposafter - xposbefore) / self.dt forward_reward = (xposafter - xposbefore) / self.dt
ctrl_cost = 0.5 * np.square(a).sum() ctrl_cost = 0.5 * np.square(a).sum()
contact_cost = ( contact_cost = (

View File

@@ -14,9 +14,7 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 20, "render_fps": 20,
} }
@@ -130,8 +128,6 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
rewards = forward_reward + healthy_reward rewards = forward_reward + healthy_reward
costs = ctrl_cost + contact_cost costs = ctrl_cost + contact_cost
self.renderer.render_step()
reward = rewards - costs reward = rewards - costs
terminated = self.terminated terminated = self.terminated
observation = self._get_obs() observation = self._get_obs()

View File

@@ -176,9 +176,7 @@ class AntEnv(MujocoEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 20, "render_fps": 20,
} }
@@ -315,7 +313,6 @@ class AntEnv(MujocoEnv, utils.EzPickle):
reward = rewards - costs reward = rewards - costs
self.renderer.render_step()
return observation, reward, terminated, False, info return observation, reward, terminated, False, info
def _get_obs(self): def _get_obs(self):

View File

@@ -10,9 +10,7 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 20, "render_fps": 20,
} }
@@ -29,8 +27,6 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
self.do_simulation(action, self.frame_skip) self.do_simulation(action, self.frame_skip)
xposafter = self.sim.data.qpos[0] xposafter = self.sim.data.qpos[0]
self.renderer.render_step()
ob = self._get_obs() ob = self._get_obs()
reward_ctrl = -0.1 * np.square(action).sum() reward_ctrl = -0.1 * np.square(action).sum()
reward_run = (xposafter - xposbefore) / self.dt reward_run = (xposafter - xposbefore) / self.dt

View File

@@ -16,9 +16,7 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 20, "render_fps": 20,
} }
@@ -79,8 +77,6 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
forward_reward = self._forward_reward_weight * x_velocity forward_reward = self._forward_reward_weight * x_velocity
self.renderer.render_step()
observation = self._get_obs() observation = self._get_obs()
reward = forward_reward - ctrl_cost reward = forward_reward - ctrl_cost
terminated = False terminated = False

View File

@@ -136,9 +136,7 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 20, "render_fps": 20,
} }
@@ -207,7 +205,6 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle):
"reward_ctrl": -ctrl_cost, "reward_ctrl": -ctrl_cost,
} }
self.renderer.render_step()
return observation, reward, terminated, False, info return observation, reward, terminated, False, info
def _get_obs(self): def _get_obs(self):

View File

@@ -10,9 +10,7 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 125, "render_fps": 125,
} }
@@ -29,8 +27,6 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
posafter, height, ang = self.sim.data.qpos[0:3] posafter, height, ang = self.sim.data.qpos[0:3]
self.renderer.render_step()
alive_bonus = 1.0 alive_bonus = 1.0
reward = (posafter - posbefore) / self.dt reward = (posafter - posbefore) / self.dt
reward += alive_bonus reward += alive_bonus

View File

@@ -19,9 +19,7 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 125, "render_fps": 125,
} }
@@ -142,8 +140,6 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
rewards = forward_reward + healthy_reward rewards = forward_reward + healthy_reward
costs = ctrl_cost costs = ctrl_cost
self.renderer.render_step()
observation = self._get_obs() observation = self._get_obs()
reward = rewards - costs reward = rewards - costs
terminated = self.terminated terminated = self.terminated

View File

@@ -142,9 +142,7 @@ class HopperEnv(MujocoEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 125, "render_fps": 125,
} }
@@ -271,7 +269,6 @@ class HopperEnv(MujocoEnv, utils.EzPickle):
"x_velocity": x_velocity, "x_velocity": x_velocity,
} }
self.renderer.render_step()
return observation, reward, terminated, False, info return observation, reward, terminated, False, info
def reset_model(self): def reset_model(self):

View File

@@ -16,9 +16,7 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 67, "render_fps": 67,
} }
@@ -50,8 +48,6 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
pos_after = mass_center(self.model, self.sim) pos_after = mass_center(self.model, self.sim)
self.renderer.render_step()
alive_bonus = 5.0 alive_bonus = 5.0
data = self.sim.data data = self.sim.data
lin_vel_cost = 1.25 * (pos_after - pos_before) / self.dt lin_vel_cost = 1.25 * (pos_after - pos_before) / self.dt

View File

@@ -23,9 +23,7 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 67, "render_fps": 67,
} }
@@ -157,8 +155,6 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
rewards = forward_reward + healthy_reward rewards = forward_reward + healthy_reward
costs = ctrl_cost + contact_cost costs = ctrl_cost + contact_cost
self.renderer.render_step()
observation = self._get_obs() observation = self._get_obs()
reward = rewards - costs reward = rewards - costs
terminated = self.terminated terminated = self.terminated

View File

@@ -216,9 +216,7 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 67, "render_fps": 67,
} }
@@ -348,7 +346,6 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):
"forward_reward": forward_reward, "forward_reward": forward_reward,
} }
self.renderer.render_step()
return observation, reward, terminated, False, info return observation, reward, terminated, False, info
def reset_model(self): def reset_model(self):

View File

@@ -10,9 +10,7 @@ class HumanoidStandupEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 67, "render_fps": 67,
} }
@@ -54,8 +52,6 @@ class HumanoidStandupEnv(MuJocoPyEnv, utils.EzPickle):
quad_impact_cost = min(quad_impact_cost, 10) quad_impact_cost = min(quad_impact_cost, 10)
reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1 reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1
self.renderer.render_step()
return ( return (
self._get_obs(), self._get_obs(),
reward, reward,

View File

@@ -182,9 +182,7 @@ class HumanoidStandupEnv(MujocoEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 67, "render_fps": 67,
} }
@@ -226,8 +224,6 @@ class HumanoidStandupEnv(MujocoEnv, utils.EzPickle):
quad_impact_cost = min(quad_impact_cost, 10) quad_impact_cost = min(quad_impact_cost, 10)
reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1 reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1
self.renderer.render_step()
return ( return (
self._get_obs(), self._get_obs(),
reward, reward,

View File

@@ -10,9 +10,7 @@ class InvertedDoublePendulumEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 20, "render_fps": 20,
} }
@@ -31,8 +29,6 @@ class InvertedDoublePendulumEnv(MuJocoPyEnv, utils.EzPickle):
def step(self, action): def step(self, action):
self.do_simulation(action, self.frame_skip) self.do_simulation(action, self.frame_skip)
self.renderer.render_step()
ob = self._get_obs() ob = self._get_obs()
x, _, y = self.sim.data.site_xpos[0] x, _, y = self.sim.data.site_xpos[0]
dist_penalty = 0.01 * x**2 + (y - 2) ** 2 dist_penalty = 0.01 * x**2 + (y - 2) ** 2

View File

@@ -116,9 +116,7 @@ class InvertedDoublePendulumEnv(MujocoEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 20, "render_fps": 20,
} }
@@ -144,7 +142,6 @@ class InvertedDoublePendulumEnv(MujocoEnv, utils.EzPickle):
alive_bonus = 10 alive_bonus = 10
r = alive_bonus - dist_penalty - vel_penalty r = alive_bonus - dist_penalty - vel_penalty
terminated = bool(y <= 1) terminated = bool(y <= 1)
self.renderer.render_step()
return ob, r, terminated, False, {} return ob, r, terminated, False, {}
def _get_obs(self): def _get_obs(self):

View File

@@ -10,9 +10,7 @@ class InvertedPendulumEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 25, "render_fps": 25,
} }
@@ -32,8 +30,6 @@ class InvertedPendulumEnv(MuJocoPyEnv, utils.EzPickle):
reward = 1.0 reward = 1.0
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
self.renderer.render_step()
ob = self._get_obs() ob = self._get_obs()
terminated = bool(not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2)) terminated = bool(not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2))
return ob, reward, terminated, False, {} return ob, reward, terminated, False, {}

View File

@@ -87,9 +87,7 @@ class InvertedPendulumEnv(MujocoEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 25, "render_fps": 25,
} }
@@ -110,7 +108,6 @@ class InvertedPendulumEnv(MujocoEnv, utils.EzPickle):
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
ob = self._get_obs() ob = self._get_obs()
terminated = bool(not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2)) terminated = bool(not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2))
self.renderer.render_step()
return ob, reward, terminated, False, {} return ob, reward, terminated, False, {}
def reset_model(self): def reset_model(self):

View File

@@ -1,4 +1,3 @@
from functools import partial
from os import path from os import path
from typing import Optional, Union from typing import Optional, Union
@@ -7,7 +6,6 @@ import numpy as np
import gym import gym
from gym import error, logger, spaces from gym import error, logger, spaces
from gym.spaces import Space from gym.spaces import Space
from gym.utils.renderer import Renderer
MUJOCO_PY_NOT_INSTALLED = False MUJOCO_PY_NOT_INSTALLED = False
MUJOCO_NOT_INSTALLED = False MUJOCO_NOT_INSTALLED = False
@@ -64,9 +62,7 @@ class BaseMujocoEnv(gym.Env):
assert self.metadata["render_modes"] == [ assert self.metadata["render_modes"] == [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], self.metadata["render_modes"] ], self.metadata["render_modes"]
assert ( assert (
int(np.round(1.0 / self.dt)) == self.metadata["render_fps"] int(np.round(1.0 / self.dt)) == self.metadata["render_fps"]
@@ -76,12 +72,8 @@ class BaseMujocoEnv(gym.Env):
self._set_action_space() self._set_action_space()
self.render_mode = render_mode self.render_mode = render_mode
render_frame = partial( self.camera_name = camera_name
self._render, self.camera_id = camera_id
camera_name=camera_name,
camera_id=camera_id,
)
self.renderer = Renderer(self.render_mode, render_frame)
def _set_action_space(self): def _set_action_space(self):
bounds = self.model.actuator_ctrlrange.copy().astype(np.float32) bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)
@@ -123,12 +115,7 @@ class BaseMujocoEnv(gym.Env):
""" """
raise NotImplementedError raise NotImplementedError
def _render( def render(self):
self,
mode: str = "human",
camera_id: Optional[int] = None,
camera_name: Optional[str] = None,
):
""" """
Render a frame from the MuJoCo simulation as specified by the render_mode. Render a frame from the MuJoCo simulation as specified by the render_mode.
""" """
@@ -147,8 +134,6 @@ class BaseMujocoEnv(gym.Env):
self._reset_simulation() self._reset_simulation()
ob = self.reset_model() ob = self.reset_model()
self.renderer.reset()
self.renderer.render_step()
return ob, {} return ob, {}
def set_state(self, qpos, qvel): def set_state(self, qpos, qvel):
@@ -170,9 +155,6 @@ class BaseMujocoEnv(gym.Env):
raise ValueError("Action dimension mismatch") raise ValueError("Action dimension mismatch")
self._step_mujoco_simulation(ctrl, n_frames) self._step_mujoco_simulation(ctrl, n_frames)
def render(self):
return self.renderer.get_renders()
def close(self): def close(self):
if self.viewer is not None: if self.viewer is not None:
self.viewer = None self.viewer = None
@@ -244,20 +226,10 @@ class MuJocoPyEnv(BaseMujocoEnv):
for _ in range(n_frames): for _ in range(n_frames):
self.sim.step() self.sim.step()
def _render( def render(self):
self,
mode: str = "human",
camera_id: Optional[int] = None,
camera_name: Optional[str] = None,
):
width, height = self.width, self.height width, height = self.width, self.height
assert mode in self.metadata["render_modes"] camera_name, camera_id = self.camera_name, self.camera_id
if mode in { if self.render_mode in {"rgb_array", "depth_array"}:
"rgb_array",
"rgb_array_list",
"depth_array",
"depth_array_list",
}:
if camera_id is not None and camera_name is not None: if camera_id is not None and camera_name is not None:
raise ValueError( raise ValueError(
"Both `camera_id` and `camera_name` cannot be" "Both `camera_id` and `camera_name` cannot be"
@@ -272,20 +244,26 @@ class MuJocoPyEnv(BaseMujocoEnv):
if camera_name in self.model._camera_name2id: if camera_name in self.model._camera_name2id:
camera_id = self.model.camera_name2id(camera_name) camera_id = self.model.camera_name2id(camera_name)
self._get_viewer(mode).render(width, height, camera_id=camera_id) self._get_viewer(self.render_mode).render(
width, height, camera_id=camera_id
)
if mode in {"rgb_array", "rgb_array_list"}: if self.render_mode == "rgb_array":
data = self._get_viewer(mode).read_pixels(width, height, depth=False) data = self._get_viewer(self.render_mode).read_pixels(
width, height, depth=False
)
# original image is upside-down, so flip it # original image is upside-down, so flip it
return data[::-1, :, :] return data[::-1, :, :]
elif mode in {"depth_array_list", "depth_array"}: elif self.render_mode == "depth_array":
self._get_viewer(mode).render(width, height) self._get_viewer(self.render_mode).render(width, height)
# Extract depth part of the read_pixels() tuple # Extract depth part of the read_pixels() tuple
data = self._get_viewer(mode).read_pixels(width, height, depth=True)[1] data = self._get_viewer(self.render_mode).read_pixels(
width, height, depth=True
)[1]
# original image is upside-down, so flip it # original image is upside-down, so flip it
return data[::-1, :] return data[::-1, :]
elif mode == "human": elif self.render_mode == "human":
self._get_viewer(mode).render() self._get_viewer(self.render_mode).render()
def _get_viewer( def _get_viewer(
self, mode self, mode
@@ -295,12 +273,7 @@ class MuJocoPyEnv(BaseMujocoEnv):
if mode == "human": if mode == "human":
self.viewer = mujoco_py.MjViewer(self.sim) self.viewer = mujoco_py.MjViewer(self.sim)
elif mode in { elif mode in {"rgb_array", "depth_array"}:
"rgb_array",
"depth_array",
"rgb_array_list",
"depth_array_list",
}:
self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, -1) self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, -1)
else: else:
raise AttributeError( raise AttributeError(
@@ -373,20 +346,14 @@ class MujocoEnv(BaseMujocoEnv):
# See https://github.com/openai/gym/issues/1541 # See https://github.com/openai/gym/issues/1541
mujoco.mj_rnePostConstraint(self.model, self.data) mujoco.mj_rnePostConstraint(self.model, self.data)
def _render( def render(self):
self, if self.render_mode in {
mode: str = "human",
camera_id: Optional[int] = None,
camera_name: Optional[str] = None,
):
assert mode in self.metadata["render_modes"]
if mode in {
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
}: }:
camera_id = self.camera_id
camera_name = self.camera_name
if camera_id is not None and camera_name is not None: if camera_id is not None and camera_name is not None:
raise ValueError( raise ValueError(
"Both `camera_id` and `camera_name` cannot be" "Both `camera_id` and `camera_name` cannot be"
@@ -404,20 +371,20 @@ class MujocoEnv(BaseMujocoEnv):
camera_name, camera_name,
) )
self._get_viewer(mode).render(camera_id=camera_id) self._get_viewer(self.render_mode).render(camera_id=camera_id)
if mode in {"rgb_array", "rgb_array_list"}: if self.render_mode == "rgb_array":
data = self._get_viewer(mode).read_pixels(depth=False) data = self._get_viewer(self.render_mode).read_pixels(depth=False)
# original image is upside-down, so flip it # original image is upside-down, so flip it
return data[::-1, :, :] return data[::-1, :, :]
elif mode in {"depth_array", "depth_array_list"}: elif self.render_mode == "depth_array":
self._get_viewer(mode).render() self._get_viewer(self.render_mode).render()
# Extract depth part of the read_pixels() tuple # Extract depth part of the read_pixels() tuple
data = self._get_viewer(mode).read_pixels(depth=True)[1] data = self._get_viewer(self.render_mode).read_pixels(depth=True)[1]
# original image is upside-down, so flip it # original image is upside-down, so flip it
return data[::-1, :] return data[::-1, :]
elif mode == "human": elif self.render_mode == "human":
self._get_viewer(mode).render() self._get_viewer(self.render_mode).render()
def close(self): def close(self):
if self.viewer is not None: if self.viewer is not None:
@@ -433,12 +400,7 @@ class MujocoEnv(BaseMujocoEnv):
from gym.envs.mujoco import Viewer from gym.envs.mujoco import Viewer
self.viewer = Viewer(self.model, self.data) self.viewer = Viewer(self.model, self.data)
elif mode in { elif mode in {"rgb_array", "depth_array"}:
"rgb_array",
"depth_array",
"rgb_array_list",
"depth_array_list",
}:
from gym.envs.mujoco import RenderContextOffscreen from gym.envs.mujoco import RenderContextOffscreen
self.viewer = RenderContextOffscreen(self.model, self.data) self.viewer = RenderContextOffscreen(self.model, self.data)

View File

@@ -10,9 +10,7 @@ class PusherEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 20, "render_fps": 20,
} }
@@ -35,8 +33,6 @@ class PusherEnv(MuJocoPyEnv, utils.EzPickle):
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
self.renderer.render_step()
ob = self._get_obs() ob = self._get_obs()
return ( return (
ob, ob,

View File

@@ -132,9 +132,7 @@ class PusherEnv(MujocoEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 20, "render_fps": 20,
} }
@@ -157,7 +155,6 @@ class PusherEnv(MujocoEnv, utils.EzPickle):
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
ob = self._get_obs() ob = self._get_obs()
self.renderer.render_step()
return ( return (
ob, ob,
reward, reward,

View File

@@ -10,9 +10,7 @@ class ReacherEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 50, "render_fps": 50,
} }
@@ -31,8 +29,6 @@ class ReacherEnv(MuJocoPyEnv, utils.EzPickle):
reward = reward_dist + reward_ctrl reward = reward_dist + reward_ctrl
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
self.renderer.render_step()
ob = self._get_obs() ob = self._get_obs()
return ( return (
ob, ob,

View File

@@ -122,9 +122,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 50, "render_fps": 50,
} }
@@ -143,7 +141,6 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
reward = reward_dist + reward_ctrl reward = reward_dist + reward_ctrl
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
ob = self._get_obs() ob = self._get_obs()
self.renderer.render_step()
return ( return (
ob, ob,
reward, reward,

View File

@@ -10,9 +10,7 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 25, "render_fps": 25,
} }
@@ -30,8 +28,6 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
xposafter = self.sim.data.qpos[0] xposafter = self.sim.data.qpos[0]
self.renderer.render_step()
reward_fwd = (xposafter - xposbefore) / self.dt reward_fwd = (xposafter - xposbefore) / self.dt
reward_ctrl = -ctrl_cost_coeff * np.square(a).sum() reward_ctrl = -ctrl_cost_coeff * np.square(a).sum()
reward = reward_fwd + reward_ctrl reward = reward_fwd + reward_ctrl

View File

@@ -14,9 +14,7 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 25, "render_fps": 25,
} }
@@ -71,8 +69,6 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
self.do_simulation(action, self.frame_skip) self.do_simulation(action, self.frame_skip)
xy_position_after = self.sim.data.qpos[0:2].copy() xy_position_after = self.sim.data.qpos[0:2].copy()
self.renderer.render_step()
xy_velocity = (xy_position_after - xy_position_before) / self.dt xy_velocity = (xy_position_after - xy_position_before) / self.dt
x_velocity, y_velocity = xy_velocity x_velocity, y_velocity = xy_velocity

View File

@@ -128,9 +128,7 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 25, "render_fps": 25,
} }
@@ -201,7 +199,6 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle):
"forward_reward": forward_reward, "forward_reward": forward_reward,
} }
self.renderer.render_step()
return observation, reward, False, False, info return observation, reward, False, False, info
def _get_obs(self): def _get_obs(self):

View File

@@ -10,9 +10,7 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 125, "render_fps": 125,
} }
@@ -29,8 +27,6 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
self.do_simulation(a, self.frame_skip) self.do_simulation(a, self.frame_skip)
posafter, height, ang = self.sim.data.qpos[0:3] posafter, height, ang = self.sim.data.qpos[0:3]
self.renderer.render_step()
alive_bonus = 1.0 alive_bonus = 1.0
reward = (posafter - posbefore) / self.dt reward = (posafter - posbefore) / self.dt
reward += alive_bonus reward += alive_bonus

View File

@@ -17,9 +17,7 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 125, "render_fps": 125,
} }
@@ -124,8 +122,6 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
x_position_after = self.sim.data.qpos[0] x_position_after = self.sim.data.qpos[0]
x_velocity = (x_position_after - x_position_before) / self.dt x_velocity = (x_position_after - x_position_before) / self.dt
self.renderer.render_step()
ctrl_cost = self.control_cost(action) ctrl_cost = self.control_cost(action)
forward_reward = self._forward_reward_weight * x_velocity forward_reward = self._forward_reward_weight * x_velocity
healthy_reward = self.healthy_reward healthy_reward = self.healthy_reward

View File

@@ -147,9 +147,7 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):
"render_modes": [ "render_modes": [
"human", "human",
"rgb_array", "rgb_array",
"rgb_array_list",
"depth_array", "depth_array",
"depth_array_list",
], ],
"render_fps": 125, "render_fps": 125,
} }
@@ -268,7 +266,6 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):
"x_velocity": x_velocity, "x_velocity": x_velocity,
} }
self.renderer.render_step()
return observation, reward, terminated, False, info return observation, reward, terminated, False, info
def reset_model(self): def reset_model(self):

View File

@@ -25,6 +25,7 @@ from gym.wrappers import (
AutoResetWrapper, AutoResetWrapper,
HumanRendering, HumanRendering,
OrderEnforcing, OrderEnforcing,
RenderCollection,
StepAPICompatibility, StepAPICompatibility,
TimeLimit, TimeLimit,
) )
@@ -581,6 +582,7 @@ def make(
mode = _kwargs.get("render_mode") mode = _kwargs.get("render_mode")
apply_human_rendering = False apply_human_rendering = False
apply_render_collection = False
# If we have access to metadata we check that "render_mode" is valid and see if the HumanRendering wrapper needs to be applied # If we have access to metadata we check that "render_mode" is valid and see if the HumanRendering wrapper needs to be applied
if mode is not None and hasattr(env_creator, "metadata"): if mode is not None and hasattr(env_creator, "metadata"):
@@ -610,6 +612,13 @@ def make(
_kwargs["render_mode"] = "rgb_array" _kwargs["render_mode"] = "rgb_array"
else: else:
_kwargs["render_mode"] = "rgb_array_list" _kwargs["render_mode"] = "rgb_array_list"
elif (
mode not in render_modes
and mode.endswith("_list")
and mode[: -len("_list")] in render_modes
):
_kwargs["render_mode"] = mode[: -len("_list")]
apply_render_collection = True
elif mode not in render_modes: elif mode not in render_modes:
logger.warn( logger.warn(
f"The environment is being initialised with mode ({mode}) that is not in the possible render_modes ({render_modes})." f"The environment is being initialised with mode ({mode}) that is not in the possible render_modes ({render_modes})."
@@ -668,6 +677,8 @@ def make(
# Add human rendering wrapper # Add human rendering wrapper
if apply_human_rendering: if apply_human_rendering:
env = HumanRendering(env) env = HumanRendering(env)
elif apply_render_collection:
env = RenderCollection(env)
return env return env

View File

@@ -6,7 +6,6 @@ import numpy as np
import gym import gym
from gym import spaces from gym import spaces
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer
def cmp(a, b): def cmp(a, b):
@@ -112,7 +111,7 @@ class BlackjackEnv(gym.Env):
""" """
metadata = { metadata = {
"render_modes": ["human", "rgb_array", "rgb_array_list"], "render_modes": ["human", "rgb_array"],
"render_fps": 4, "render_fps": 4,
} }
@@ -130,7 +129,6 @@ class BlackjackEnv(gym.Env):
self.sab = sab self.sab = sab
self.render_mode = render_mode self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
def step(self, action): def step(self, action):
assert self.action_space.contains(action) assert self.action_space.contains(action)
@@ -158,7 +156,6 @@ class BlackjackEnv(gym.Env):
): ):
# Natural gives extra points, but doesn't autowin. Legacy implementation # Natural gives extra points, but doesn't autowin. Legacy implementation
reward = 1.5 reward = 1.5
self.renderer.render_step()
return self._get_obs(), reward, terminated, False, {} return self._get_obs(), reward, terminated, False, {}
def _get_obs(self): def _get_obs(self):
@@ -185,17 +182,9 @@ class BlackjackEnv(gym.Env):
else: else:
self.dealer_top_card_value_str = str(dealer_card_value) self.dealer_top_card_value_str = str(dealer_card_value)
self.renderer.reset()
self.renderer.render_step()
return self._get_obs(), {} return self._get_obs(), {}
def render(self): def render(self):
return self.renderer.get_renders()
def _render(self, mode: str = "human"):
assert mode in self.metadata["render_modes"]
try: try:
import pygame import pygame
except ImportError: except ImportError:
@@ -214,7 +203,7 @@ class BlackjackEnv(gym.Env):
if not hasattr(self, "screen"): if not hasattr(self, "screen"):
pygame.init() pygame.init()
if mode == "human": if self.render_mode == "human":
pygame.display.init() pygame.display.init()
self.screen = pygame.display.set_mode((screen_width, screen_height)) self.screen = pygame.display.set_mode((screen_width, screen_height))
else: else:
@@ -296,7 +285,7 @@ class BlackjackEnv(gym.Env):
player_sum_text_rect.bottom + spacing // 2, player_sum_text_rect.bottom + spacing // 2,
), ),
) )
if mode == "human": if self.render_mode == "human":
pygame.event.pump() pygame.event.pump()
pygame.display.update() pygame.display.update()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])

View File

@@ -7,7 +7,7 @@ import numpy as np
from gym import Env, spaces from gym import Env, spaces
from gym.envs.toy_text.utils import categorical_sample from gym.envs.toy_text.utils import categorical_sample
from gym.utils.renderer import Renderer from gym.error import DependencyNotInstalled
UP = 0 UP = 0
RIGHT = 1 RIGHT = 1
@@ -63,7 +63,7 @@ class CliffWalkingEnv(Env):
""" """
metadata = { metadata = {
"render_modes": ["human", "rgb_array", "rgb_array_list", "ansi"], "render_modes": ["human", "rgb_array", "ansi"],
"render_fps": 4, "render_fps": 4,
} }
@@ -97,7 +97,6 @@ class CliffWalkingEnv(Env):
self.action_space = spaces.Discrete(self.nA) self.action_space = spaces.Discrete(self.nA)
self.render_mode = render_mode self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
# pygame utils # pygame utils
self.cell_size = (60, 60) self.cell_size = (60, 60)
@@ -149,30 +148,28 @@ class CliffWalkingEnv(Env):
p, s, r, t = transitions[i] p, s, r, t = transitions[i]
self.s = s self.s = s
self.lastaction = a self.lastaction = a
self.renderer.render_step()
return (int(s), r, t, False, {"prob": p}) return (int(s), r, t, False, {"prob": p})
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed) super().reset(seed=seed)
self.s = categorical_sample(self.initial_state_distrib, self.np_random) self.s = categorical_sample(self.initial_state_distrib, self.np_random)
self.lastaction = None self.lastaction = None
self.renderer.reset()
self.renderer.render_step()
return int(self.s), {"prob": 1} return int(self.s), {"prob": 1}
def render(self): def render(self):
return self.renderer.get_renders() if self.render_mode == "ansi":
def _render(self, mode="human"):
if mode == "ansi":
return self._render_text() return self._render_text()
else: else:
return self._render_gui(mode) return self._render_gui(self.render_mode)
def _render_gui(self, mode): def _render_gui(self, mode):
try:
import pygame import pygame
except ImportError:
raise DependencyNotInstalled(
"pygame is not installed, run `pip install gym[toy_text]`"
)
if self.window_surface is None: if self.window_surface is None:
pygame.init() pygame.init()

View File

@@ -8,7 +8,6 @@ import numpy as np
from gym import Env, spaces, utils from gym import Env, spaces, utils
from gym.envs.toy_text.utils import categorical_sample from gym.envs.toy_text.utils import categorical_sample
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer
LEFT = 0 LEFT = 0
DOWN = 1 DOWN = 1
@@ -156,7 +155,7 @@ class FrozenLakeEnv(Env):
""" """
metadata = { metadata = {
"render_modes": ["human", "ansi", "rgb_array", "rgb_array_list"], "render_modes": ["human", "ansi", "rgb_array"],
"render_fps": 4, "render_fps": 4,
} }
@@ -226,7 +225,6 @@ class FrozenLakeEnv(Env):
self.action_space = spaces.Discrete(nA) self.action_space = spaces.Discrete(nA)
self.render_mode = render_mode self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
# pygame utils # pygame utils
self.window_size = (min(64 * ncol, 512), min(64 * nrow, 512)) self.window_size = (min(64 * ncol, 512), min(64 * nrow, 512))
@@ -249,7 +247,6 @@ class FrozenLakeEnv(Env):
p, s, r, t = transitions[i] p, s, r, t = transitions[i]
self.s = s self.s = s
self.lastaction = a self.lastaction = a
self.renderer.render_step()
return (int(s), r, t, False, {"prob": p}) return (int(s), r, t, False, {"prob": p})
def reset( def reset(
@@ -262,20 +259,13 @@ class FrozenLakeEnv(Env):
self.s = categorical_sample(self.initial_state_distrib, self.np_random) self.s = categorical_sample(self.initial_state_distrib, self.np_random)
self.lastaction = None self.lastaction = None
self.renderer.reset()
self.renderer.render_step()
return int(self.s), {"prob": 1} return int(self.s), {"prob": 1}
def render(self): def render(self):
return self.renderer.get_renders() if self.render_mode == "ansi":
def _render(self, mode="human"):
assert mode in self.metadata["render_modes"]
if mode == "ansi":
return self._render_text() return self._render_text()
elif mode in {"human", "rgb_array", "rgb_array_list"}: else: # self.render_mode in {"human", "rgb_array"}:
return self._render_gui(mode) return self._render_gui(self.render_mode)
def _render_gui(self, mode): def _render_gui(self, mode):
try: try:
@@ -292,7 +282,7 @@ class FrozenLakeEnv(Env):
pygame.display.init() pygame.display.init()
pygame.display.set_caption("Frozen Lake") pygame.display.set_caption("Frozen Lake")
self.window_surface = pygame.display.set_mode(self.window_size) self.window_surface = pygame.display.set_mode(self.window_size)
elif mode in {"rgb_array", "rgb_array_list"}: elif mode == "rgb_array":
self.window_surface = pygame.Surface(self.window_size) self.window_surface = pygame.Surface(self.window_size)
assert ( assert (
@@ -370,7 +360,7 @@ class FrozenLakeEnv(Env):
pygame.event.pump() pygame.event.pump()
pygame.display.update() pygame.display.update()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
elif mode in {"rgb_array", "rgb_array_list"}: elif mode == "rgb_array":
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.window_surface)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.window_surface)), axes=(1, 0, 2)
) )

View File

@@ -8,7 +8,6 @@ import numpy as np
from gym import Env, spaces, utils from gym import Env, spaces, utils
from gym.envs.toy_text.utils import categorical_sample from gym.envs.toy_text.utils import categorical_sample
from gym.error import DependencyNotInstalled from gym.error import DependencyNotInstalled
from gym.utils.renderer import Renderer
MAP = [ MAP = [
"+---------+", "+---------+",
@@ -122,7 +121,7 @@ class TaxiEnv(Env):
""" """
metadata = { metadata = {
"render_modes": ["human", "ansi", "rgb_array", "rgb_array_list"], "render_modes": ["human", "ansi", "rgb_array"],
"render_fps": 4, "render_fps": 4,
} }
@@ -192,7 +191,6 @@ class TaxiEnv(Env):
self.observation_space = spaces.Discrete(num_states) self.observation_space = spaces.Discrete(num_states)
self.render_mode = render_mode self.render_mode = render_mode
self.renderer = Renderer(self.render_mode, self._render)
# pygame utils # pygame utils
self.window = None self.window = None
@@ -259,7 +257,6 @@ class TaxiEnv(Env):
p, s, r, t = transitions[i] p, s, r, t = transitions[i]
self.s = s self.s = s
self.lastaction = a self.lastaction = a
self.renderer.render_step()
return (int(s), r, t, False, {"prob": p, "action_mask": self.action_mask(s)}) return (int(s), r, t, False, {"prob": p, "action_mask": self.action_mask(s)})
def reset( def reset(
@@ -272,20 +269,14 @@ class TaxiEnv(Env):
self.s = categorical_sample(self.initial_state_distrib, self.np_random) self.s = categorical_sample(self.initial_state_distrib, self.np_random)
self.lastaction = None self.lastaction = None
self.taxi_orientation = 0 self.taxi_orientation = 0
self.renderer.reset()
self.renderer.render_step()
return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)} return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)}
def render(self): def render(self):
return self.renderer.get_renders() if self.render_mode == "ansi":
def _render(self, mode):
assert mode in self.metadata["render_modes"]
if mode == "ansi":
return self._render_text() return self._render_text()
elif mode in {"human", "rgb_array", "rgb_array_list"}: else: # self.render_mode in {"human", "rgb_array"}:
return self._render_gui(mode) return self._render_gui(self.render_mode)
def _render_gui(self, mode): def _render_gui(self, mode):
try: try:
@@ -300,7 +291,7 @@ class TaxiEnv(Env):
pygame.display.set_caption("Taxi") pygame.display.set_caption("Taxi")
if mode == "human": if mode == "human":
self.window = pygame.display.set_mode(WINDOW_SIZE) self.window = pygame.display.set_mode(WINDOW_SIZE)
elif mode in {"rgb_array", "rgb_array_list"}: elif mode == "rgb_array":
self.window = pygame.Surface(WINDOW_SIZE) self.window = pygame.Surface(WINDOW_SIZE)
assert ( assert (
@@ -412,7 +403,7 @@ class TaxiEnv(Env):
if mode == "human": if mode == "human":
pygame.display.update() pygame.display.update()
self.clock.tick(self.metadata["render_fps"]) self.clock.tick(self.metadata["render_fps"])
elif mode in {"rgb_array", "rgb_array_list"}: elif mode == "rgb_array":
return np.transpose( return np.transpose(
np.array(pygame.surfarray.pixels3d(self.window)), axes=(1, 0, 2) np.array(pygame.surfarray.pixels3d(self.window)), axes=(1, 0, 2)
) )

View File

@@ -1,79 +0,0 @@
"""A utility class to collect render frames from a function that computes a single frame."""
from typing import Any, Callable, List, Optional, Set
# list of modes with which render function returns None
NO_RETURNS_RENDER = {"human"}
# list of modes with which render returns just a single frame of the current state
SINGLE_RENDER = {"rgb_array", "depth_array", "state_pixels", "ansi"}
class Renderer:
"""This class serves to easily integrate collection of renders for environments that can computes a single render.
To use this function:
- instantiate this class with the mode and the function that computes a single frame
- call render_step method each time the frame should be saved in the list
(usually at the end of the step and reset methods)
- call get_renders whenever you want to retrieve renders
(usually in the render method)
- call reset to clean the render list
(usually in the reset method of the environment)
"""
def __init__(
self,
mode: Optional[str],
render: Callable[[str], Any],
no_returns_render: Optional[Set[str]] = None,
single_render: Optional[Set[str]] = None,
):
"""Instantiates a Renderer object.
Args:
mode (Optional[str]): Way to render
render (Callable[[str], Any]): Function that receives the mode and computes a single frame
no_returns_render (Optional[Set[str]]): Set of render modes that don't return any value.
The default value is the set {"human"}.
single_render (Optional[Set[str]]): Set of render modes that should return a single frame.
The default value is the set {"rgb_array", "depth_array", "state_pixels", "ansi"}.
"""
if no_returns_render is None:
no_returns_render = NO_RETURNS_RENDER
if single_render is None:
single_render = SINGLE_RENDER
self.no_returns_render = no_returns_render
self.single_render = single_render
self.mode = mode
self.render = render
self.render_list = []
def render_step(self) -> None:
"""Computes a frame and save it to the render collection list.
This method should be usually called inside environment's step and reset method.
"""
if self.mode is not None and self.mode not in self.single_render:
render_return = self.render(self.mode)
if self.mode not in self.no_returns_render:
self.render_list.append(render_return)
def get_renders(self) -> Optional[List]:
"""Pops all the frames from the render collection list.
This method should be usually called in the environment's render method to retrieve the frames collected till this time step.
"""
if self.mode in self.single_render:
return self.render(self.mode)
elif self.mode is not None and self.mode not in self.no_returns_render:
renders = self.render_list
self.render_list = []
return renders
def reset(self):
"""Resets the render collection list.
This method should be usually called inside environment's reset method.
"""
self.render_list = []

View File

@@ -12,6 +12,7 @@ from gym.wrappers.normalize import NormalizeObservation, NormalizeReward
from gym.wrappers.order_enforcing import OrderEnforcing from gym.wrappers.order_enforcing import OrderEnforcing
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
from gym.wrappers.render_collection import RenderCollection
from gym.wrappers.rescale_action import RescaleAction from gym.wrappers.rescale_action import RescaleAction
from gym.wrappers.resize_observation import ResizeObservation from gym.wrappers.resize_observation import ResizeObservation
from gym.wrappers.step_api_compatibility import StepAPICompatibility from gym.wrappers.step_api_compatibility import StepAPICompatibility

View File

@@ -0,0 +1,52 @@
"""A wrapper that adds render collection mode to an environment."""
import gym
class RenderCollection(gym.Wrapper):
"""Save collection of render frames."""
def __init__(self, env: gym.Env, pop_frames: bool = True, reset_clean: bool = True):
"""Initialize a :class:`RenderCollection` instance.
Args:
env: The environment that is being wrapped
pop_frames (bool): If true, clear the collection frames after .render() is called.
Default value is True.
reset_clean (bool): If true, clear the collection frames when .reset() is called.
Default value is True.
"""
super().__init__(env)
assert env.render_mode is not None
assert not env.render_mode.endswith("_list")
self.frame_list = []
self.reset_clean = reset_clean
self.pop_frames = pop_frames
@property
def render_mode(self):
"""Returns the collection render_mode name."""
return f"{self.env.render_mode}_list"
def step(self, *args, **kwargs):
"""Perform a step in the base environment and collect a frame."""
output = self.env.step(*args, **kwargs)
self.frame_list.append(self.env.render())
return output
def reset(self, *args, **kwargs):
"""Reset the base environment, eventually clear the frame_list, and collect a frame."""
result = self.env.reset(*args, **kwargs)
if self.reset_clean:
self.frame_list = []
self.frame_list.append(self.env.render())
return result
def render(self):
"""Returns the collection of frames and, if pop_frames = True, clears it."""
frames = self.frame_list
if self.pop_frames:
self.frame_list = []
return frames