mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 22:11:25 +00:00
environment agnostic {mode}_list render mode (#3060)
This commit is contained in:
@@ -9,7 +9,6 @@ import gym
|
||||
from gym import error, spaces
|
||||
from gym.error import DependencyNotInstalled
|
||||
from gym.utils import EzPickle
|
||||
from gym.utils.renderer import Renderer
|
||||
|
||||
try:
|
||||
import Box2D
|
||||
@@ -164,7 +163,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
"""
|
||||
|
||||
metadata = {
|
||||
"render_modes": ["human", "rgb_array", "rgb_array_list"],
|
||||
"render_modes": ["human", "rgb_array"],
|
||||
"render_fps": FPS,
|
||||
}
|
||||
|
||||
@@ -258,7 +257,6 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
# state += [l.fraction for l in self.lidar]
|
||||
|
||||
self.render_mode = render_mode
|
||||
self.renderer = Renderer(self.render_mode, self._render)
|
||||
self.screen: Optional[pygame.Surface] = None
|
||||
self.clock = None
|
||||
|
||||
@@ -512,7 +510,6 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
return fraction
|
||||
|
||||
self.lidar = [LidarCallback() for _ in range(10)]
|
||||
self.renderer.reset()
|
||||
return self.step(np.array([0, 0, 0, 0]))[0], {}
|
||||
|
||||
def step(self, action: np.ndarray):
|
||||
@@ -601,14 +598,9 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
terminated = True
|
||||
if pos[0] > (TERRAIN_LENGTH - TERRAIN_GRASS) * TERRAIN_STEP:
|
||||
terminated = True
|
||||
self.renderer.render_step()
|
||||
return np.array(state, dtype=np.float32), reward, terminated, False, {}
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
||||
def _render(self, mode: str = "human"):
|
||||
assert mode in self.metadata["render_modes"]
|
||||
try:
|
||||
import pygame
|
||||
from pygame import gfxdraw
|
||||
@@ -617,7 +609,7 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
"pygame is not installed, run `pip install gym[box2d]`"
|
||||
)
|
||||
|
||||
if self.screen is None and mode == "human":
|
||||
if self.screen is None and self.render_mode == "human":
|
||||
pygame.init()
|
||||
pygame.display.init()
|
||||
self.screen = pygame.display.set_mode((VIEWPORT_W, VIEWPORT_H))
|
||||
@@ -736,13 +728,13 @@ class BipedalWalker(gym.Env, EzPickle):
|
||||
|
||||
self.surf = pygame.transform.flip(self.surf, False, True)
|
||||
|
||||
if mode == "human":
|
||||
if self.render_mode == "human":
|
||||
assert self.screen is not None
|
||||
self.screen.blit(self.surf, (-self.scroll * SCALE, 0))
|
||||
pygame.event.pump()
|
||||
self.clock.tick(self.metadata["render_fps"])
|
||||
pygame.display.flip()
|
||||
elif mode in {"rgb_array", "rgb_array_list"}:
|
||||
elif self.render_mode == "rgb_array":
|
||||
return np.transpose(
|
||||
np.array(pygame.surfarray.pixels3d(self.surf)), axes=(1, 0, 2)
|
||||
)[:, -VIEWPORT_W:]
|
||||
|
@@ -10,7 +10,6 @@ from gym import spaces
|
||||
from gym.envs.box2d.car_dynamics import Car
|
||||
from gym.error import DependencyNotInstalled, InvalidAction
|
||||
from gym.utils import EzPickle
|
||||
from gym.utils.renderer import Renderer
|
||||
|
||||
try:
|
||||
import Box2D
|
||||
@@ -184,8 +183,6 @@ class CarRacing(gym.Env, EzPickle):
|
||||
metadata = {
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array_list",
|
||||
"state_pixels_list",
|
||||
"rgb_array",
|
||||
"state_pixels",
|
||||
],
|
||||
@@ -247,7 +244,6 @@ class CarRacing(gym.Env, EzPickle):
|
||||
)
|
||||
|
||||
self.render_mode = render_mode
|
||||
self.renderer = Renderer(self.render_mode, self._render)
|
||||
|
||||
def _destroy(self):
|
||||
if not self.road:
|
||||
@@ -517,7 +513,6 @@ class CarRacing(gym.Env, EzPickle):
|
||||
)
|
||||
self.car = Car(self.world, *self.track[0][1:4])
|
||||
|
||||
self.renderer.reset()
|
||||
return self.step(None)[0], {}
|
||||
|
||||
def step(self, action: Union[np.ndarray, int]):
|
||||
@@ -563,13 +558,12 @@ class CarRacing(gym.Env, EzPickle):
|
||||
terminated = True
|
||||
step_reward = -100
|
||||
|
||||
self.renderer.render_step()
|
||||
return self.state, step_reward, terminated, truncated, {}
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
return self._render(self.render_mode)
|
||||
|
||||
def _render(self, mode: str = "human"):
|
||||
def _render(self, mode: str):
|
||||
assert mode in self.metadata["render_modes"]
|
||||
|
||||
pygame.font.init()
|
||||
@@ -623,9 +617,9 @@ class CarRacing(gym.Env, EzPickle):
|
||||
self.screen.blit(self.surf, (0, 0))
|
||||
pygame.display.flip()
|
||||
|
||||
if mode in {"rgb_array", "rgb_array_list"}:
|
||||
if mode == "rgb_array":
|
||||
return self._create_image_array(self.surf, (VIDEO_W, VIDEO_H))
|
||||
elif mode in {"state_pixels_list", "state_pixels"}:
|
||||
elif mode == "state_pixels":
|
||||
return self._create_image_array(self.surf, (STATE_W, STATE_H))
|
||||
else:
|
||||
return self.isopen
|
||||
|
@@ -10,7 +10,6 @@ import gym
|
||||
from gym import error, spaces
|
||||
from gym.error import DependencyNotInstalled
|
||||
from gym.utils import EzPickle, colorize
|
||||
from gym.utils.renderer import Renderer
|
||||
from gym.utils.step_api_compatibility import step_api_compatibility
|
||||
|
||||
try:
|
||||
@@ -179,7 +178,7 @@ class LunarLander(gym.Env, EzPickle):
|
||||
"""
|
||||
|
||||
metadata = {
|
||||
"render_modes": ["human", "rgb_array", "rgb_array_list"],
|
||||
"render_modes": ["human", "rgb_array"],
|
||||
"render_fps": FPS,
|
||||
}
|
||||
|
||||
@@ -287,7 +286,6 @@ class LunarLander(gym.Env, EzPickle):
|
||||
self.action_space = spaces.Discrete(4)
|
||||
|
||||
self.render_mode = render_mode
|
||||
self.renderer = Renderer(self.render_mode, self._render)
|
||||
|
||||
def _destroy(self):
|
||||
if not self.moon:
|
||||
@@ -411,7 +409,6 @@ class LunarLander(gym.Env, EzPickle):
|
||||
|
||||
self.drawlist = [self.lander] + self.legs
|
||||
|
||||
self.renderer.reset()
|
||||
return self.step(np.array([0, 0]) if self.continuous else 0)[0], {}
|
||||
|
||||
def _create_particle(self, mass, x, y, ttl):
|
||||
@@ -589,14 +586,9 @@ class LunarLander(gym.Env, EzPickle):
|
||||
if not self.lander.awake:
|
||||
terminated = True
|
||||
reward = +100
|
||||
self.renderer.render_step()
|
||||
return np.array(state, dtype=np.float32), reward, terminated, False, {}
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
||||
def _render(self, mode="human"):
|
||||
assert mode in self.metadata["render_modes"]
|
||||
try:
|
||||
import pygame
|
||||
from pygame import gfxdraw
|
||||
@@ -605,7 +597,7 @@ class LunarLander(gym.Env, EzPickle):
|
||||
"pygame is not installed, run `pip install gym[box2d]`"
|
||||
)
|
||||
|
||||
if self.screen is None and mode == "human":
|
||||
if self.screen is None and self.render_mode == "human":
|
||||
pygame.init()
|
||||
pygame.display.init()
|
||||
self.screen = pygame.display.set_mode((VIEWPORT_W, VIEWPORT_H))
|
||||
@@ -692,13 +684,13 @@ class LunarLander(gym.Env, EzPickle):
|
||||
|
||||
self.surf = pygame.transform.flip(self.surf, False, True)
|
||||
|
||||
if mode == "human":
|
||||
if self.render_mode == "human":
|
||||
assert self.screen is not None
|
||||
self.screen.blit(self.surf, (0, 0))
|
||||
pygame.event.pump()
|
||||
self.clock.tick(self.metadata["render_fps"])
|
||||
pygame.display.flip()
|
||||
elif mode in {"rgb_array", "rgb_array_list"}:
|
||||
elif self.render_mode == "rgb_array":
|
||||
return np.transpose(
|
||||
np.array(pygame.surfarray.pixels3d(self.surf)), axes=(1, 0, 2)
|
||||
)
|
||||
|
@@ -21,7 +21,6 @@ __author__ = "Christoph Dann <cdann@cdann.de>"
|
||||
# SOURCE:
|
||||
# https://github.com/rlpy/rlpy/blob/master/rlpy/Domains/Acrobot.py
|
||||
from gym.envs.classic_control import utils
|
||||
from gym.utils.renderer import Renderer
|
||||
|
||||
|
||||
class AcrobotEnv(core.Env):
|
||||
@@ -137,7 +136,7 @@ class AcrobotEnv(core.Env):
|
||||
"""
|
||||
|
||||
metadata = {
|
||||
"render_modes": ["human", "rgb_array", "rgb_array_list"],
|
||||
"render_modes": ["human", "rgb_array"],
|
||||
"render_fps": 15,
|
||||
}
|
||||
|
||||
@@ -168,7 +167,6 @@ class AcrobotEnv(core.Env):
|
||||
|
||||
def __init__(self, render_mode: Optional[str] = None):
|
||||
self.render_mode = render_mode
|
||||
self.renderer = Renderer(self.render_mode, self._render)
|
||||
self.screen = None
|
||||
self.clock = None
|
||||
self.isopen = True
|
||||
@@ -191,8 +189,6 @@ class AcrobotEnv(core.Env):
|
||||
np.float32
|
||||
)
|
||||
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
return self._get_ob(), {}
|
||||
|
||||
def step(self, a):
|
||||
@@ -220,7 +216,6 @@ class AcrobotEnv(core.Env):
|
||||
terminated = self._terminal()
|
||||
reward = -1.0 if not terminated else 0.0
|
||||
|
||||
self.renderer.render_step()
|
||||
return (self._get_ob(), reward, terminated, False, {})
|
||||
|
||||
def _get_ob(self):
|
||||
@@ -278,10 +273,6 @@ class AcrobotEnv(core.Env):
|
||||
return dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
||||
def _render(self, mode="human"):
|
||||
assert mode in self.metadata["render_modes"]
|
||||
try:
|
||||
import pygame
|
||||
from pygame import gfxdraw
|
||||
@@ -292,12 +283,12 @@ class AcrobotEnv(core.Env):
|
||||
|
||||
if self.screen is None:
|
||||
pygame.init()
|
||||
if mode == "human":
|
||||
if self.render_mode == "human":
|
||||
pygame.display.init()
|
||||
self.screen = pygame.display.set_mode(
|
||||
(self.SCREEN_DIM, self.SCREEN_DIM)
|
||||
)
|
||||
else: # mode in {"rgb_array", "rgb_array_list"}
|
||||
else: # mode in "rgb_array"
|
||||
self.screen = pygame.Surface((self.SCREEN_DIM, self.SCREEN_DIM))
|
||||
if self.clock is None:
|
||||
self.clock = pygame.time.Clock()
|
||||
@@ -353,12 +344,12 @@ class AcrobotEnv(core.Env):
|
||||
surf = pygame.transform.flip(surf, False, True)
|
||||
self.screen.blit(surf, (0, 0))
|
||||
|
||||
if mode == "human":
|
||||
if self.render_mode == "human":
|
||||
pygame.event.pump()
|
||||
self.clock.tick(self.metadata["render_fps"])
|
||||
pygame.display.flip()
|
||||
|
||||
elif mode in {"rgb_array", "rgb_array_list"}:
|
||||
elif self.render_mode == "rgb_array":
|
||||
return np.transpose(
|
||||
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
|
||||
)
|
||||
|
@@ -12,7 +12,6 @@ import gym
|
||||
from gym import logger, spaces
|
||||
from gym.envs.classic_control import utils
|
||||
from gym.error import DependencyNotInstalled
|
||||
from gym.utils.renderer import Renderer
|
||||
|
||||
|
||||
class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
||||
@@ -83,7 +82,7 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
||||
"""
|
||||
|
||||
metadata = {
|
||||
"render_modes": ["human", "rgb_array", "rgb_array_list"],
|
||||
"render_modes": ["human", "rgb_array"],
|
||||
"render_fps": 50,
|
||||
}
|
||||
|
||||
@@ -118,7 +117,6 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
||||
self.observation_space = spaces.Box(-high, high, dtype=np.float32)
|
||||
|
||||
self.render_mode = render_mode
|
||||
self.renderer = Renderer(self.render_mode, self._render)
|
||||
|
||||
self.screen_width = 600
|
||||
self.screen_height = 400
|
||||
@@ -185,7 +183,6 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
||||
self.steps_beyond_terminated += 1
|
||||
reward = 0.0
|
||||
|
||||
self.renderer.render_step()
|
||||
return np.array(self.state, dtype=np.float32), reward, terminated, False, {}
|
||||
|
||||
def reset(
|
||||
@@ -202,15 +199,9 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
||||
) # default high
|
||||
self.state = self.np_random.uniform(low=low, high=high, size=(4,))
|
||||
self.steps_beyond_terminated = None
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
return np.array(self.state, dtype=np.float32), {}
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
||||
def _render(self, mode="human"):
|
||||
assert mode in self.metadata["render_modes"]
|
||||
try:
|
||||
import pygame
|
||||
from pygame import gfxdraw
|
||||
@@ -221,12 +212,12 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
||||
|
||||
if self.screen is None:
|
||||
pygame.init()
|
||||
if mode == "human":
|
||||
if self.render_mode == "human":
|
||||
pygame.display.init()
|
||||
self.screen = pygame.display.set_mode(
|
||||
(self.screen_width, self.screen_height)
|
||||
)
|
||||
else: # mode in {"rgb_array", "rgb_array_list"}
|
||||
else: # mode == "rgb_array"
|
||||
self.screen = pygame.Surface((self.screen_width, self.screen_height))
|
||||
if self.clock is None:
|
||||
self.clock = pygame.time.Clock()
|
||||
@@ -289,12 +280,12 @@ class CartPoleEnv(gym.Env[np.ndarray, Union[int, np.ndarray]]):
|
||||
|
||||
self.surf = pygame.transform.flip(self.surf, False, True)
|
||||
self.screen.blit(self.surf, (0, 0))
|
||||
if mode == "human":
|
||||
if self.render_mode == "human":
|
||||
pygame.event.pump()
|
||||
self.clock.tick(self.metadata["render_fps"])
|
||||
pygame.display.flip()
|
||||
|
||||
elif mode in {"rgb_array", "rgb_array_list"}:
|
||||
elif self.render_mode == "rgb_array":
|
||||
return np.transpose(
|
||||
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
|
||||
)
|
||||
|
@@ -22,7 +22,6 @@ import gym
|
||||
from gym import spaces
|
||||
from gym.envs.classic_control import utils
|
||||
from gym.error import DependencyNotInstalled
|
||||
from gym.utils.renderer import Renderer
|
||||
|
||||
|
||||
class Continuous_MountainCarEnv(gym.Env):
|
||||
@@ -102,7 +101,7 @@ class Continuous_MountainCarEnv(gym.Env):
|
||||
"""
|
||||
|
||||
metadata = {
|
||||
"render_modes": ["human", "rgb_array", "rgb_array_list"],
|
||||
"render_modes": ["human", "rgb_array"],
|
||||
"render_fps": 30,
|
||||
}
|
||||
|
||||
@@ -126,7 +125,6 @@ class Continuous_MountainCarEnv(gym.Env):
|
||||
)
|
||||
|
||||
self.render_mode = render_mode
|
||||
self.renderer = Renderer(self.render_mode, self._render)
|
||||
|
||||
self.screen_width = 600
|
||||
self.screen_height = 400
|
||||
@@ -171,7 +169,6 @@ class Continuous_MountainCarEnv(gym.Env):
|
||||
reward -= math.pow(action[0], 2) * 0.1
|
||||
|
||||
self.state = np.array([position, velocity], dtype=np.float32)
|
||||
self.renderer.render_step()
|
||||
return self.state, reward, terminated, False, {}
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
@@ -180,19 +177,12 @@ class Continuous_MountainCarEnv(gym.Env):
|
||||
# state/observations.
|
||||
low, high = utils.maybe_parse_reset_bounds(options, -0.6, -0.4)
|
||||
self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
return np.array(self.state, dtype=np.float32), {}
|
||||
|
||||
def _height(self, xs):
|
||||
return np.sin(3 * xs) * 0.45 + 0.55
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
||||
def _render(self, mode="human"):
|
||||
assert mode in self.metadata["render_modes"]
|
||||
|
||||
try:
|
||||
import pygame
|
||||
from pygame import gfxdraw
|
||||
@@ -203,12 +193,12 @@ class Continuous_MountainCarEnv(gym.Env):
|
||||
|
||||
if self.screen is None:
|
||||
pygame.init()
|
||||
if mode == "human":
|
||||
if self.render_mode == "human":
|
||||
pygame.display.init()
|
||||
self.screen = pygame.display.set_mode(
|
||||
(self.screen_width, self.screen_height)
|
||||
)
|
||||
else: # mode in {"rgb_array", "rgb_array_list"}
|
||||
else: # mode == "rgb_array":
|
||||
self.screen = pygame.Surface((self.screen_width, self.screen_height))
|
||||
if self.clock is None:
|
||||
self.clock = pygame.time.Clock()
|
||||
@@ -277,12 +267,12 @@ class Continuous_MountainCarEnv(gym.Env):
|
||||
|
||||
self.surf = pygame.transform.flip(self.surf, False, True)
|
||||
self.screen.blit(self.surf, (0, 0))
|
||||
if mode == "human":
|
||||
if self.render_mode == "human":
|
||||
pygame.event.pump()
|
||||
self.clock.tick(self.metadata["render_fps"])
|
||||
pygame.display.flip()
|
||||
|
||||
elif mode in {"rgb_array", "rgb_array_list"}:
|
||||
elif self.render_mode == "rgb_array":
|
||||
return np.transpose(
|
||||
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
|
||||
)
|
||||
|
@@ -11,7 +11,6 @@ import gym
|
||||
from gym import spaces
|
||||
from gym.envs.classic_control import utils
|
||||
from gym.error import DependencyNotInstalled
|
||||
from gym.utils.renderer import Renderer
|
||||
|
||||
|
||||
class MountainCarEnv(gym.Env):
|
||||
@@ -97,7 +96,7 @@ class MountainCarEnv(gym.Env):
|
||||
"""
|
||||
|
||||
metadata = {
|
||||
"render_modes": ["human", "rgb_array", "rgb_array_list"],
|
||||
"render_modes": ["human", "rgb_array"],
|
||||
"render_fps": 30,
|
||||
}
|
||||
|
||||
@@ -115,7 +114,6 @@ class MountainCarEnv(gym.Env):
|
||||
self.high = np.array([self.max_position, self.max_speed], dtype=np.float32)
|
||||
|
||||
self.render_mode = render_mode
|
||||
self.renderer = Renderer(self.render_mode, self._render)
|
||||
|
||||
self.screen_width = 600
|
||||
self.screen_height = 400
|
||||
@@ -145,7 +143,6 @@ class MountainCarEnv(gym.Env):
|
||||
reward = -1.0
|
||||
|
||||
self.state = (position, velocity)
|
||||
self.renderer.render_step()
|
||||
return np.array(self.state, dtype=np.float32), reward, terminated, False, {}
|
||||
|
||||
def reset(
|
||||
@@ -159,18 +156,12 @@ class MountainCarEnv(gym.Env):
|
||||
# state/observations.
|
||||
low, high = utils.maybe_parse_reset_bounds(options, -0.6, -0.4)
|
||||
self.state = np.array([self.np_random.uniform(low=low, high=high), 0])
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
return np.array(self.state, dtype=np.float32), {}
|
||||
|
||||
def _height(self, xs):
|
||||
return np.sin(3 * xs) * 0.45 + 0.55
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
||||
def _render(self, mode="human"):
|
||||
assert mode in self.metadata["render_modes"]
|
||||
try:
|
||||
import pygame
|
||||
from pygame import gfxdraw
|
||||
@@ -181,12 +172,12 @@ class MountainCarEnv(gym.Env):
|
||||
|
||||
if self.screen is None:
|
||||
pygame.init()
|
||||
if mode == "human":
|
||||
if self.render_mode == "human":
|
||||
pygame.display.init()
|
||||
self.screen = pygame.display.set_mode(
|
||||
(self.screen_width, self.screen_height)
|
||||
)
|
||||
else: # mode in {"rgb_array", "rgb_array_list"}
|
||||
else: # mode in "rgb_array"
|
||||
self.screen = pygame.Surface((self.screen_width, self.screen_height))
|
||||
if self.clock is None:
|
||||
self.clock = pygame.time.Clock()
|
||||
@@ -255,12 +246,12 @@ class MountainCarEnv(gym.Env):
|
||||
|
||||
self.surf = pygame.transform.flip(self.surf, False, True)
|
||||
self.screen.blit(self.surf, (0, 0))
|
||||
if mode == "human":
|
||||
if self.render_mode == "human":
|
||||
pygame.event.pump()
|
||||
self.clock.tick(self.metadata["render_fps"])
|
||||
pygame.display.flip()
|
||||
|
||||
elif mode in {"rgb_array", "rgb_array_list"}:
|
||||
elif self.render_mode == "rgb_array":
|
||||
return np.transpose(
|
||||
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
|
||||
)
|
||||
|
@@ -9,7 +9,6 @@ import gym
|
||||
from gym import spaces
|
||||
from gym.envs.classic_control import utils
|
||||
from gym.error import DependencyNotInstalled
|
||||
from gym.utils.renderer import Renderer
|
||||
|
||||
DEFAULT_X = np.pi
|
||||
DEFAULT_Y = 1.0
|
||||
@@ -89,7 +88,7 @@ class PendulumEnv(gym.Env):
|
||||
"""
|
||||
|
||||
metadata = {
|
||||
"render_modes": ["human", "rgb_array", "rgb_array_list"],
|
||||
"render_modes": ["human", "rgb_array"],
|
||||
"render_fps": 30,
|
||||
}
|
||||
|
||||
@@ -102,7 +101,6 @@ class PendulumEnv(gym.Env):
|
||||
self.l = 1.0
|
||||
|
||||
self.render_mode = render_mode
|
||||
self.renderer = Renderer(self.render_mode, self._render)
|
||||
|
||||
self.screen_dim = 500
|
||||
self.screen = None
|
||||
@@ -135,7 +133,6 @@ class PendulumEnv(gym.Env):
|
||||
newth = th + newthdot * dt
|
||||
|
||||
self.state = np.array([newth, newthdot])
|
||||
self.renderer.render_step()
|
||||
return self._get_obs(), -costs, False, False, {}
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
@@ -154,8 +151,6 @@ class PendulumEnv(gym.Env):
|
||||
self.state = self.np_random.uniform(low=low, high=high)
|
||||
self.last_u = None
|
||||
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
return self._get_obs(), {}
|
||||
|
||||
def _get_obs(self):
|
||||
@@ -163,10 +158,6 @@ class PendulumEnv(gym.Env):
|
||||
return np.array([np.cos(theta), np.sin(theta), thetadot], dtype=np.float32)
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
||||
def _render(self, mode="human"):
|
||||
assert mode in self.metadata["render_modes"]
|
||||
try:
|
||||
import pygame
|
||||
from pygame import gfxdraw
|
||||
@@ -177,12 +168,12 @@ class PendulumEnv(gym.Env):
|
||||
|
||||
if self.screen is None:
|
||||
pygame.init()
|
||||
if mode == "human":
|
||||
if self.render_mode == "human":
|
||||
pygame.display.init()
|
||||
self.screen = pygame.display.set_mode(
|
||||
(self.screen_dim, self.screen_dim)
|
||||
)
|
||||
else: # mode in {"rgb_array", "rgb_array_list"}
|
||||
else: # mode in "rgb_array"
|
||||
self.screen = pygame.Surface((self.screen_dim, self.screen_dim))
|
||||
if self.clock is None:
|
||||
self.clock = pygame.time.Clock()
|
||||
@@ -244,12 +235,12 @@ class PendulumEnv(gym.Env):
|
||||
|
||||
self.surf = pygame.transform.flip(self.surf, False, True)
|
||||
self.screen.blit(self.surf, (0, 0))
|
||||
if mode == "human":
|
||||
if self.render_mode == "human":
|
||||
pygame.event.pump()
|
||||
self.clock.tick(self.metadata["render_fps"])
|
||||
pygame.display.flip()
|
||||
|
||||
else: # mode == "rgb_array_list":
|
||||
else: # mode == "rgb_array":
|
||||
return np.transpose(
|
||||
np.array(pygame.surfarray.pixels3d(self.screen)), axes=(1, 0, 2)
|
||||
)
|
||||
|
@@ -10,9 +10,7 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 20,
|
||||
}
|
||||
@@ -31,8 +29,6 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
self.do_simulation(a, self.frame_skip)
|
||||
xposafter = self.get_body_com("torso")[0]
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
forward_reward = (xposafter - xposbefore) / self.dt
|
||||
ctrl_cost = 0.5 * np.square(a).sum()
|
||||
contact_cost = (
|
||||
|
@@ -14,9 +14,7 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 20,
|
||||
}
|
||||
@@ -130,8 +128,6 @@ class AntEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
rewards = forward_reward + healthy_reward
|
||||
costs = ctrl_cost + contact_cost
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
reward = rewards - costs
|
||||
terminated = self.terminated
|
||||
observation = self._get_obs()
|
||||
|
@@ -176,9 +176,7 @@ class AntEnv(MujocoEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 20,
|
||||
}
|
||||
@@ -315,7 +313,6 @@ class AntEnv(MujocoEnv, utils.EzPickle):
|
||||
|
||||
reward = rewards - costs
|
||||
|
||||
self.renderer.render_step()
|
||||
return observation, reward, terminated, False, info
|
||||
|
||||
def _get_obs(self):
|
||||
|
@@ -10,9 +10,7 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 20,
|
||||
}
|
||||
@@ -29,8 +27,6 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
self.do_simulation(action, self.frame_skip)
|
||||
xposafter = self.sim.data.qpos[0]
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
ob = self._get_obs()
|
||||
reward_ctrl = -0.1 * np.square(action).sum()
|
||||
reward_run = (xposafter - xposbefore) / self.dt
|
||||
|
@@ -16,9 +16,7 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 20,
|
||||
}
|
||||
@@ -79,8 +77,6 @@ class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
|
||||
forward_reward = self._forward_reward_weight * x_velocity
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
observation = self._get_obs()
|
||||
reward = forward_reward - ctrl_cost
|
||||
terminated = False
|
||||
|
@@ -136,9 +136,7 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 20,
|
||||
}
|
||||
@@ -207,7 +205,6 @@ class HalfCheetahEnv(MujocoEnv, utils.EzPickle):
|
||||
"reward_ctrl": -ctrl_cost,
|
||||
}
|
||||
|
||||
self.renderer.render_step()
|
||||
return observation, reward, terminated, False, info
|
||||
|
||||
def _get_obs(self):
|
||||
|
@@ -10,9 +10,7 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 125,
|
||||
}
|
||||
@@ -29,8 +27,6 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
self.do_simulation(a, self.frame_skip)
|
||||
posafter, height, ang = self.sim.data.qpos[0:3]
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
alive_bonus = 1.0
|
||||
reward = (posafter - posbefore) / self.dt
|
||||
reward += alive_bonus
|
||||
|
@@ -19,9 +19,7 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 125,
|
||||
}
|
||||
@@ -142,8 +140,6 @@ class HopperEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
rewards = forward_reward + healthy_reward
|
||||
costs = ctrl_cost
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
observation = self._get_obs()
|
||||
reward = rewards - costs
|
||||
terminated = self.terminated
|
||||
|
@@ -142,9 +142,7 @@ class HopperEnv(MujocoEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 125,
|
||||
}
|
||||
@@ -271,7 +269,6 @@ class HopperEnv(MujocoEnv, utils.EzPickle):
|
||||
"x_velocity": x_velocity,
|
||||
}
|
||||
|
||||
self.renderer.render_step()
|
||||
return observation, reward, terminated, False, info
|
||||
|
||||
def reset_model(self):
|
||||
|
@@ -16,9 +16,7 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 67,
|
||||
}
|
||||
@@ -50,8 +48,6 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
self.do_simulation(a, self.frame_skip)
|
||||
pos_after = mass_center(self.model, self.sim)
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
alive_bonus = 5.0
|
||||
data = self.sim.data
|
||||
lin_vel_cost = 1.25 * (pos_after - pos_before) / self.dt
|
||||
|
@@ -23,9 +23,7 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 67,
|
||||
}
|
||||
@@ -157,8 +155,6 @@ class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
rewards = forward_reward + healthy_reward
|
||||
costs = ctrl_cost + contact_cost
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
observation = self._get_obs()
|
||||
reward = rewards - costs
|
||||
terminated = self.terminated
|
||||
|
@@ -216,9 +216,7 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 67,
|
||||
}
|
||||
@@ -348,7 +346,6 @@ class HumanoidEnv(MujocoEnv, utils.EzPickle):
|
||||
"forward_reward": forward_reward,
|
||||
}
|
||||
|
||||
self.renderer.render_step()
|
||||
return observation, reward, terminated, False, info
|
||||
|
||||
def reset_model(self):
|
||||
|
@@ -10,9 +10,7 @@ class HumanoidStandupEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 67,
|
||||
}
|
||||
@@ -54,8 +52,6 @@ class HumanoidStandupEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
quad_impact_cost = min(quad_impact_cost, 10)
|
||||
reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
return (
|
||||
self._get_obs(),
|
||||
reward,
|
||||
|
@@ -182,9 +182,7 @@ class HumanoidStandupEnv(MujocoEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 67,
|
||||
}
|
||||
@@ -226,8 +224,6 @@ class HumanoidStandupEnv(MujocoEnv, utils.EzPickle):
|
||||
quad_impact_cost = min(quad_impact_cost, 10)
|
||||
reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
return (
|
||||
self._get_obs(),
|
||||
reward,
|
||||
|
@@ -10,9 +10,7 @@ class InvertedDoublePendulumEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 20,
|
||||
}
|
||||
@@ -31,8 +29,6 @@ class InvertedDoublePendulumEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
def step(self, action):
|
||||
self.do_simulation(action, self.frame_skip)
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
ob = self._get_obs()
|
||||
x, _, y = self.sim.data.site_xpos[0]
|
||||
dist_penalty = 0.01 * x**2 + (y - 2) ** 2
|
||||
|
@@ -116,9 +116,7 @@ class InvertedDoublePendulumEnv(MujocoEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 20,
|
||||
}
|
||||
@@ -144,7 +142,6 @@ class InvertedDoublePendulumEnv(MujocoEnv, utils.EzPickle):
|
||||
alive_bonus = 10
|
||||
r = alive_bonus - dist_penalty - vel_penalty
|
||||
terminated = bool(y <= 1)
|
||||
self.renderer.render_step()
|
||||
return ob, r, terminated, False, {}
|
||||
|
||||
def _get_obs(self):
|
||||
|
@@ -10,9 +10,7 @@ class InvertedPendulumEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 25,
|
||||
}
|
||||
@@ -32,8 +30,6 @@ class InvertedPendulumEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
reward = 1.0
|
||||
self.do_simulation(a, self.frame_skip)
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
ob = self._get_obs()
|
||||
terminated = bool(not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2))
|
||||
return ob, reward, terminated, False, {}
|
||||
|
@@ -87,9 +87,7 @@ class InvertedPendulumEnv(MujocoEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 25,
|
||||
}
|
||||
@@ -110,7 +108,6 @@ class InvertedPendulumEnv(MujocoEnv, utils.EzPickle):
|
||||
self.do_simulation(a, self.frame_skip)
|
||||
ob = self._get_obs()
|
||||
terminated = bool(not np.isfinite(ob).all() or (np.abs(ob[1]) > 0.2))
|
||||
self.renderer.render_step()
|
||||
return ob, reward, terminated, False, {}
|
||||
|
||||
def reset_model(self):
|
||||
|
@@ -1,4 +1,3 @@
|
||||
from functools import partial
|
||||
from os import path
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -7,7 +6,6 @@ import numpy as np
|
||||
import gym
|
||||
from gym import error, logger, spaces
|
||||
from gym.spaces import Space
|
||||
from gym.utils.renderer import Renderer
|
||||
|
||||
MUJOCO_PY_NOT_INSTALLED = False
|
||||
MUJOCO_NOT_INSTALLED = False
|
||||
@@ -64,9 +62,7 @@ class BaseMujocoEnv(gym.Env):
|
||||
assert self.metadata["render_modes"] == [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
], self.metadata["render_modes"]
|
||||
assert (
|
||||
int(np.round(1.0 / self.dt)) == self.metadata["render_fps"]
|
||||
@@ -76,12 +72,8 @@ class BaseMujocoEnv(gym.Env):
|
||||
self._set_action_space()
|
||||
|
||||
self.render_mode = render_mode
|
||||
render_frame = partial(
|
||||
self._render,
|
||||
camera_name=camera_name,
|
||||
camera_id=camera_id,
|
||||
)
|
||||
self.renderer = Renderer(self.render_mode, render_frame)
|
||||
self.camera_name = camera_name
|
||||
self.camera_id = camera_id
|
||||
|
||||
def _set_action_space(self):
|
||||
bounds = self.model.actuator_ctrlrange.copy().astype(np.float32)
|
||||
@@ -123,12 +115,7 @@ class BaseMujocoEnv(gym.Env):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _render(
|
||||
self,
|
||||
mode: str = "human",
|
||||
camera_id: Optional[int] = None,
|
||||
camera_name: Optional[str] = None,
|
||||
):
|
||||
def render(self):
|
||||
"""
|
||||
Render a frame from the MuJoCo simulation as specified by the render_mode.
|
||||
"""
|
||||
@@ -147,8 +134,6 @@ class BaseMujocoEnv(gym.Env):
|
||||
self._reset_simulation()
|
||||
|
||||
ob = self.reset_model()
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
return ob, {}
|
||||
|
||||
def set_state(self, qpos, qvel):
|
||||
@@ -170,9 +155,6 @@ class BaseMujocoEnv(gym.Env):
|
||||
raise ValueError("Action dimension mismatch")
|
||||
self._step_mujoco_simulation(ctrl, n_frames)
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
||||
def close(self):
|
||||
if self.viewer is not None:
|
||||
self.viewer = None
|
||||
@@ -244,20 +226,10 @@ class MuJocoPyEnv(BaseMujocoEnv):
|
||||
for _ in range(n_frames):
|
||||
self.sim.step()
|
||||
|
||||
def _render(
|
||||
self,
|
||||
mode: str = "human",
|
||||
camera_id: Optional[int] = None,
|
||||
camera_name: Optional[str] = None,
|
||||
):
|
||||
def render(self):
|
||||
width, height = self.width, self.height
|
||||
assert mode in self.metadata["render_modes"]
|
||||
if mode in {
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
}:
|
||||
camera_name, camera_id = self.camera_name, self.camera_id
|
||||
if self.render_mode in {"rgb_array", "depth_array"}:
|
||||
if camera_id is not None and camera_name is not None:
|
||||
raise ValueError(
|
||||
"Both `camera_id` and `camera_name` cannot be"
|
||||
@@ -272,20 +244,26 @@ class MuJocoPyEnv(BaseMujocoEnv):
|
||||
if camera_name in self.model._camera_name2id:
|
||||
camera_id = self.model.camera_name2id(camera_name)
|
||||
|
||||
self._get_viewer(mode).render(width, height, camera_id=camera_id)
|
||||
self._get_viewer(self.render_mode).render(
|
||||
width, height, camera_id=camera_id
|
||||
)
|
||||
|
||||
if mode in {"rgb_array", "rgb_array_list"}:
|
||||
data = self._get_viewer(mode).read_pixels(width, height, depth=False)
|
||||
if self.render_mode == "rgb_array":
|
||||
data = self._get_viewer(self.render_mode).read_pixels(
|
||||
width, height, depth=False
|
||||
)
|
||||
# original image is upside-down, so flip it
|
||||
return data[::-1, :, :]
|
||||
elif mode in {"depth_array_list", "depth_array"}:
|
||||
self._get_viewer(mode).render(width, height)
|
||||
elif self.render_mode == "depth_array":
|
||||
self._get_viewer(self.render_mode).render(width, height)
|
||||
# Extract depth part of the read_pixels() tuple
|
||||
data = self._get_viewer(mode).read_pixels(width, height, depth=True)[1]
|
||||
data = self._get_viewer(self.render_mode).read_pixels(
|
||||
width, height, depth=True
|
||||
)[1]
|
||||
# original image is upside-down, so flip it
|
||||
return data[::-1, :]
|
||||
elif mode == "human":
|
||||
self._get_viewer(mode).render()
|
||||
elif self.render_mode == "human":
|
||||
self._get_viewer(self.render_mode).render()
|
||||
|
||||
def _get_viewer(
|
||||
self, mode
|
||||
@@ -295,12 +273,7 @@ class MuJocoPyEnv(BaseMujocoEnv):
|
||||
if mode == "human":
|
||||
self.viewer = mujoco_py.MjViewer(self.sim)
|
||||
|
||||
elif mode in {
|
||||
"rgb_array",
|
||||
"depth_array",
|
||||
"rgb_array_list",
|
||||
"depth_array_list",
|
||||
}:
|
||||
elif mode in {"rgb_array", "depth_array"}:
|
||||
self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, -1)
|
||||
else:
|
||||
raise AttributeError(
|
||||
@@ -373,20 +346,14 @@ class MujocoEnv(BaseMujocoEnv):
|
||||
# See https://github.com/openai/gym/issues/1541
|
||||
mujoco.mj_rnePostConstraint(self.model, self.data)
|
||||
|
||||
def _render(
|
||||
self,
|
||||
mode: str = "human",
|
||||
camera_id: Optional[int] = None,
|
||||
camera_name: Optional[str] = None,
|
||||
):
|
||||
assert mode in self.metadata["render_modes"]
|
||||
|
||||
if mode in {
|
||||
def render(self):
|
||||
if self.render_mode in {
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
}:
|
||||
camera_id = self.camera_id
|
||||
camera_name = self.camera_name
|
||||
|
||||
if camera_id is not None and camera_name is not None:
|
||||
raise ValueError(
|
||||
"Both `camera_id` and `camera_name` cannot be"
|
||||
@@ -404,20 +371,20 @@ class MujocoEnv(BaseMujocoEnv):
|
||||
camera_name,
|
||||
)
|
||||
|
||||
self._get_viewer(mode).render(camera_id=camera_id)
|
||||
self._get_viewer(self.render_mode).render(camera_id=camera_id)
|
||||
|
||||
if mode in {"rgb_array", "rgb_array_list"}:
|
||||
data = self._get_viewer(mode).read_pixels(depth=False)
|
||||
if self.render_mode == "rgb_array":
|
||||
data = self._get_viewer(self.render_mode).read_pixels(depth=False)
|
||||
# original image is upside-down, so flip it
|
||||
return data[::-1, :, :]
|
||||
elif mode in {"depth_array", "depth_array_list"}:
|
||||
self._get_viewer(mode).render()
|
||||
elif self.render_mode == "depth_array":
|
||||
self._get_viewer(self.render_mode).render()
|
||||
# Extract depth part of the read_pixels() tuple
|
||||
data = self._get_viewer(mode).read_pixels(depth=True)[1]
|
||||
data = self._get_viewer(self.render_mode).read_pixels(depth=True)[1]
|
||||
# original image is upside-down, so flip it
|
||||
return data[::-1, :]
|
||||
elif mode == "human":
|
||||
self._get_viewer(mode).render()
|
||||
elif self.render_mode == "human":
|
||||
self._get_viewer(self.render_mode).render()
|
||||
|
||||
def close(self):
|
||||
if self.viewer is not None:
|
||||
@@ -433,12 +400,7 @@ class MujocoEnv(BaseMujocoEnv):
|
||||
from gym.envs.mujoco import Viewer
|
||||
|
||||
self.viewer = Viewer(self.model, self.data)
|
||||
elif mode in {
|
||||
"rgb_array",
|
||||
"depth_array",
|
||||
"rgb_array_list",
|
||||
"depth_array_list",
|
||||
}:
|
||||
elif mode in {"rgb_array", "depth_array"}:
|
||||
from gym.envs.mujoco import RenderContextOffscreen
|
||||
|
||||
self.viewer = RenderContextOffscreen(self.model, self.data)
|
||||
|
@@ -10,9 +10,7 @@ class PusherEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 20,
|
||||
}
|
||||
@@ -35,8 +33,6 @@ class PusherEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
|
||||
self.do_simulation(a, self.frame_skip)
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
ob = self._get_obs()
|
||||
return (
|
||||
ob,
|
||||
|
@@ -132,9 +132,7 @@ class PusherEnv(MujocoEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 20,
|
||||
}
|
||||
@@ -157,7 +155,6 @@ class PusherEnv(MujocoEnv, utils.EzPickle):
|
||||
|
||||
self.do_simulation(a, self.frame_skip)
|
||||
ob = self._get_obs()
|
||||
self.renderer.render_step()
|
||||
return (
|
||||
ob,
|
||||
reward,
|
||||
|
@@ -10,9 +10,7 @@ class ReacherEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 50,
|
||||
}
|
||||
@@ -31,8 +29,6 @@ class ReacherEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
reward = reward_dist + reward_ctrl
|
||||
self.do_simulation(a, self.frame_skip)
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
ob = self._get_obs()
|
||||
return (
|
||||
ob,
|
||||
|
@@ -122,9 +122,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 50,
|
||||
}
|
||||
@@ -143,7 +141,6 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
|
||||
reward = reward_dist + reward_ctrl
|
||||
self.do_simulation(a, self.frame_skip)
|
||||
ob = self._get_obs()
|
||||
self.renderer.render_step()
|
||||
return (
|
||||
ob,
|
||||
reward,
|
||||
|
@@ -10,9 +10,7 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 25,
|
||||
}
|
||||
@@ -30,8 +28,6 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
self.do_simulation(a, self.frame_skip)
|
||||
xposafter = self.sim.data.qpos[0]
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
reward_fwd = (xposafter - xposbefore) / self.dt
|
||||
reward_ctrl = -ctrl_cost_coeff * np.square(a).sum()
|
||||
reward = reward_fwd + reward_ctrl
|
||||
|
@@ -14,9 +14,7 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 25,
|
||||
}
|
||||
@@ -71,8 +69,6 @@ class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
self.do_simulation(action, self.frame_skip)
|
||||
xy_position_after = self.sim.data.qpos[0:2].copy()
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
xy_velocity = (xy_position_after - xy_position_before) / self.dt
|
||||
x_velocity, y_velocity = xy_velocity
|
||||
|
||||
|
@@ -128,9 +128,7 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 25,
|
||||
}
|
||||
@@ -201,7 +199,6 @@ class SwimmerEnv(MujocoEnv, utils.EzPickle):
|
||||
"forward_reward": forward_reward,
|
||||
}
|
||||
|
||||
self.renderer.render_step()
|
||||
return observation, reward, False, False, info
|
||||
|
||||
def _get_obs(self):
|
||||
|
@@ -10,9 +10,7 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 125,
|
||||
}
|
||||
@@ -29,8 +27,6 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
self.do_simulation(a, self.frame_skip)
|
||||
posafter, height, ang = self.sim.data.qpos[0:3]
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
alive_bonus = 1.0
|
||||
reward = (posafter - posbefore) / self.dt
|
||||
reward += alive_bonus
|
||||
|
@@ -17,9 +17,7 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 125,
|
||||
}
|
||||
@@ -124,8 +122,6 @@ class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
|
||||
x_position_after = self.sim.data.qpos[0]
|
||||
x_velocity = (x_position_after - x_position_before) / self.dt
|
||||
|
||||
self.renderer.render_step()
|
||||
|
||||
ctrl_cost = self.control_cost(action)
|
||||
forward_reward = self._forward_reward_weight * x_velocity
|
||||
healthy_reward = self.healthy_reward
|
||||
|
@@ -147,9 +147,7 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):
|
||||
"render_modes": [
|
||||
"human",
|
||||
"rgb_array",
|
||||
"rgb_array_list",
|
||||
"depth_array",
|
||||
"depth_array_list",
|
||||
],
|
||||
"render_fps": 125,
|
||||
}
|
||||
@@ -268,7 +266,6 @@ class Walker2dEnv(MujocoEnv, utils.EzPickle):
|
||||
"x_velocity": x_velocity,
|
||||
}
|
||||
|
||||
self.renderer.render_step()
|
||||
return observation, reward, terminated, False, info
|
||||
|
||||
def reset_model(self):
|
||||
|
@@ -25,6 +25,7 @@ from gym.wrappers import (
|
||||
AutoResetWrapper,
|
||||
HumanRendering,
|
||||
OrderEnforcing,
|
||||
RenderCollection,
|
||||
StepAPICompatibility,
|
||||
TimeLimit,
|
||||
)
|
||||
@@ -581,6 +582,7 @@ def make(
|
||||
|
||||
mode = _kwargs.get("render_mode")
|
||||
apply_human_rendering = False
|
||||
apply_render_collection = False
|
||||
|
||||
# If we have access to metadata we check that "render_mode" is valid and see if the HumanRendering wrapper needs to be applied
|
||||
if mode is not None and hasattr(env_creator, "metadata"):
|
||||
@@ -610,6 +612,13 @@ def make(
|
||||
_kwargs["render_mode"] = "rgb_array"
|
||||
else:
|
||||
_kwargs["render_mode"] = "rgb_array_list"
|
||||
elif (
|
||||
mode not in render_modes
|
||||
and mode.endswith("_list")
|
||||
and mode[: -len("_list")] in render_modes
|
||||
):
|
||||
_kwargs["render_mode"] = mode[: -len("_list")]
|
||||
apply_render_collection = True
|
||||
elif mode not in render_modes:
|
||||
logger.warn(
|
||||
f"The environment is being initialised with mode ({mode}) that is not in the possible render_modes ({render_modes})."
|
||||
@@ -668,6 +677,8 @@ def make(
|
||||
# Add human rendering wrapper
|
||||
if apply_human_rendering:
|
||||
env = HumanRendering(env)
|
||||
elif apply_render_collection:
|
||||
env = RenderCollection(env)
|
||||
|
||||
return env
|
||||
|
||||
|
@@ -6,7 +6,6 @@ import numpy as np
|
||||
import gym
|
||||
from gym import spaces
|
||||
from gym.error import DependencyNotInstalled
|
||||
from gym.utils.renderer import Renderer
|
||||
|
||||
|
||||
def cmp(a, b):
|
||||
@@ -112,7 +111,7 @@ class BlackjackEnv(gym.Env):
|
||||
"""
|
||||
|
||||
metadata = {
|
||||
"render_modes": ["human", "rgb_array", "rgb_array_list"],
|
||||
"render_modes": ["human", "rgb_array"],
|
||||
"render_fps": 4,
|
||||
}
|
||||
|
||||
@@ -130,7 +129,6 @@ class BlackjackEnv(gym.Env):
|
||||
self.sab = sab
|
||||
|
||||
self.render_mode = render_mode
|
||||
self.renderer = Renderer(self.render_mode, self._render)
|
||||
|
||||
def step(self, action):
|
||||
assert self.action_space.contains(action)
|
||||
@@ -158,7 +156,6 @@ class BlackjackEnv(gym.Env):
|
||||
):
|
||||
# Natural gives extra points, but doesn't autowin. Legacy implementation
|
||||
reward = 1.5
|
||||
self.renderer.render_step()
|
||||
return self._get_obs(), reward, terminated, False, {}
|
||||
|
||||
def _get_obs(self):
|
||||
@@ -185,17 +182,9 @@ class BlackjackEnv(gym.Env):
|
||||
else:
|
||||
self.dealer_top_card_value_str = str(dealer_card_value)
|
||||
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
|
||||
return self._get_obs(), {}
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
||||
def _render(self, mode: str = "human"):
|
||||
assert mode in self.metadata["render_modes"]
|
||||
|
||||
try:
|
||||
import pygame
|
||||
except ImportError:
|
||||
@@ -214,7 +203,7 @@ class BlackjackEnv(gym.Env):
|
||||
|
||||
if not hasattr(self, "screen"):
|
||||
pygame.init()
|
||||
if mode == "human":
|
||||
if self.render_mode == "human":
|
||||
pygame.display.init()
|
||||
self.screen = pygame.display.set_mode((screen_width, screen_height))
|
||||
else:
|
||||
@@ -296,7 +285,7 @@ class BlackjackEnv(gym.Env):
|
||||
player_sum_text_rect.bottom + spacing // 2,
|
||||
),
|
||||
)
|
||||
if mode == "human":
|
||||
if self.render_mode == "human":
|
||||
pygame.event.pump()
|
||||
pygame.display.update()
|
||||
self.clock.tick(self.metadata["render_fps"])
|
||||
|
@@ -7,7 +7,7 @@ import numpy as np
|
||||
|
||||
from gym import Env, spaces
|
||||
from gym.envs.toy_text.utils import categorical_sample
|
||||
from gym.utils.renderer import Renderer
|
||||
from gym.error import DependencyNotInstalled
|
||||
|
||||
UP = 0
|
||||
RIGHT = 1
|
||||
@@ -63,7 +63,7 @@ class CliffWalkingEnv(Env):
|
||||
"""
|
||||
|
||||
metadata = {
|
||||
"render_modes": ["human", "rgb_array", "rgb_array_list", "ansi"],
|
||||
"render_modes": ["human", "rgb_array", "ansi"],
|
||||
"render_fps": 4,
|
||||
}
|
||||
|
||||
@@ -97,7 +97,6 @@ class CliffWalkingEnv(Env):
|
||||
self.action_space = spaces.Discrete(self.nA)
|
||||
|
||||
self.render_mode = render_mode
|
||||
self.renderer = Renderer(self.render_mode, self._render)
|
||||
|
||||
# pygame utils
|
||||
self.cell_size = (60, 60)
|
||||
@@ -149,30 +148,28 @@ class CliffWalkingEnv(Env):
|
||||
p, s, r, t = transitions[i]
|
||||
self.s = s
|
||||
self.lastaction = a
|
||||
self.renderer.render_step()
|
||||
return (int(s), r, t, False, {"prob": p})
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
super().reset(seed=seed)
|
||||
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
|
||||
self.lastaction = None
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
|
||||
return int(self.s), {"prob": 1}
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
||||
def _render(self, mode="human"):
|
||||
if mode == "ansi":
|
||||
if self.render_mode == "ansi":
|
||||
return self._render_text()
|
||||
else:
|
||||
return self._render_gui(mode)
|
||||
return self._render_gui(self.render_mode)
|
||||
|
||||
def _render_gui(self, mode):
|
||||
import pygame
|
||||
|
||||
try:
|
||||
import pygame
|
||||
except ImportError:
|
||||
raise DependencyNotInstalled(
|
||||
"pygame is not installed, run `pip install gym[toy_text]`"
|
||||
)
|
||||
if self.window_surface is None:
|
||||
pygame.init()
|
||||
|
||||
|
@@ -8,7 +8,6 @@ import numpy as np
|
||||
from gym import Env, spaces, utils
|
||||
from gym.envs.toy_text.utils import categorical_sample
|
||||
from gym.error import DependencyNotInstalled
|
||||
from gym.utils.renderer import Renderer
|
||||
|
||||
LEFT = 0
|
||||
DOWN = 1
|
||||
@@ -156,7 +155,7 @@ class FrozenLakeEnv(Env):
|
||||
"""
|
||||
|
||||
metadata = {
|
||||
"render_modes": ["human", "ansi", "rgb_array", "rgb_array_list"],
|
||||
"render_modes": ["human", "ansi", "rgb_array"],
|
||||
"render_fps": 4,
|
||||
}
|
||||
|
||||
@@ -226,7 +225,6 @@ class FrozenLakeEnv(Env):
|
||||
self.action_space = spaces.Discrete(nA)
|
||||
|
||||
self.render_mode = render_mode
|
||||
self.renderer = Renderer(self.render_mode, self._render)
|
||||
|
||||
# pygame utils
|
||||
self.window_size = (min(64 * ncol, 512), min(64 * nrow, 512))
|
||||
@@ -249,7 +247,6 @@ class FrozenLakeEnv(Env):
|
||||
p, s, r, t = transitions[i]
|
||||
self.s = s
|
||||
self.lastaction = a
|
||||
self.renderer.render_step()
|
||||
return (int(s), r, t, False, {"prob": p})
|
||||
|
||||
def reset(
|
||||
@@ -262,20 +259,13 @@ class FrozenLakeEnv(Env):
|
||||
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
|
||||
self.lastaction = None
|
||||
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
|
||||
return int(self.s), {"prob": 1}
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
||||
def _render(self, mode="human"):
|
||||
assert mode in self.metadata["render_modes"]
|
||||
if mode == "ansi":
|
||||
if self.render_mode == "ansi":
|
||||
return self._render_text()
|
||||
elif mode in {"human", "rgb_array", "rgb_array_list"}:
|
||||
return self._render_gui(mode)
|
||||
else: # self.render_mode in {"human", "rgb_array"}:
|
||||
return self._render_gui(self.render_mode)
|
||||
|
||||
def _render_gui(self, mode):
|
||||
try:
|
||||
@@ -292,7 +282,7 @@ class FrozenLakeEnv(Env):
|
||||
pygame.display.init()
|
||||
pygame.display.set_caption("Frozen Lake")
|
||||
self.window_surface = pygame.display.set_mode(self.window_size)
|
||||
elif mode in {"rgb_array", "rgb_array_list"}:
|
||||
elif mode == "rgb_array":
|
||||
self.window_surface = pygame.Surface(self.window_size)
|
||||
|
||||
assert (
|
||||
@@ -370,7 +360,7 @@ class FrozenLakeEnv(Env):
|
||||
pygame.event.pump()
|
||||
pygame.display.update()
|
||||
self.clock.tick(self.metadata["render_fps"])
|
||||
elif mode in {"rgb_array", "rgb_array_list"}:
|
||||
elif mode == "rgb_array":
|
||||
return np.transpose(
|
||||
np.array(pygame.surfarray.pixels3d(self.window_surface)), axes=(1, 0, 2)
|
||||
)
|
||||
|
@@ -8,7 +8,6 @@ import numpy as np
|
||||
from gym import Env, spaces, utils
|
||||
from gym.envs.toy_text.utils import categorical_sample
|
||||
from gym.error import DependencyNotInstalled
|
||||
from gym.utils.renderer import Renderer
|
||||
|
||||
MAP = [
|
||||
"+---------+",
|
||||
@@ -122,7 +121,7 @@ class TaxiEnv(Env):
|
||||
"""
|
||||
|
||||
metadata = {
|
||||
"render_modes": ["human", "ansi", "rgb_array", "rgb_array_list"],
|
||||
"render_modes": ["human", "ansi", "rgb_array"],
|
||||
"render_fps": 4,
|
||||
}
|
||||
|
||||
@@ -192,7 +191,6 @@ class TaxiEnv(Env):
|
||||
self.observation_space = spaces.Discrete(num_states)
|
||||
|
||||
self.render_mode = render_mode
|
||||
self.renderer = Renderer(self.render_mode, self._render)
|
||||
|
||||
# pygame utils
|
||||
self.window = None
|
||||
@@ -259,7 +257,6 @@ class TaxiEnv(Env):
|
||||
p, s, r, t = transitions[i]
|
||||
self.s = s
|
||||
self.lastaction = a
|
||||
self.renderer.render_step()
|
||||
return (int(s), r, t, False, {"prob": p, "action_mask": self.action_mask(s)})
|
||||
|
||||
def reset(
|
||||
@@ -272,20 +269,14 @@ class TaxiEnv(Env):
|
||||
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
|
||||
self.lastaction = None
|
||||
self.taxi_orientation = 0
|
||||
self.renderer.reset()
|
||||
self.renderer.render_step()
|
||||
|
||||
return int(self.s), {"prob": 1.0, "action_mask": self.action_mask(self.s)}
|
||||
|
||||
def render(self):
|
||||
return self.renderer.get_renders()
|
||||
|
||||
def _render(self, mode):
|
||||
assert mode in self.metadata["render_modes"]
|
||||
if mode == "ansi":
|
||||
if self.render_mode == "ansi":
|
||||
return self._render_text()
|
||||
elif mode in {"human", "rgb_array", "rgb_array_list"}:
|
||||
return self._render_gui(mode)
|
||||
else: # self.render_mode in {"human", "rgb_array"}:
|
||||
return self._render_gui(self.render_mode)
|
||||
|
||||
def _render_gui(self, mode):
|
||||
try:
|
||||
@@ -300,7 +291,7 @@ class TaxiEnv(Env):
|
||||
pygame.display.set_caption("Taxi")
|
||||
if mode == "human":
|
||||
self.window = pygame.display.set_mode(WINDOW_SIZE)
|
||||
elif mode in {"rgb_array", "rgb_array_list"}:
|
||||
elif mode == "rgb_array":
|
||||
self.window = pygame.Surface(WINDOW_SIZE)
|
||||
|
||||
assert (
|
||||
@@ -412,7 +403,7 @@ class TaxiEnv(Env):
|
||||
if mode == "human":
|
||||
pygame.display.update()
|
||||
self.clock.tick(self.metadata["render_fps"])
|
||||
elif mode in {"rgb_array", "rgb_array_list"}:
|
||||
elif mode == "rgb_array":
|
||||
return np.transpose(
|
||||
np.array(pygame.surfarray.pixels3d(self.window)), axes=(1, 0, 2)
|
||||
)
|
||||
|
@@ -1,79 +0,0 @@
|
||||
"""A utility class to collect render frames from a function that computes a single frame."""
|
||||
from typing import Any, Callable, List, Optional, Set
|
||||
|
||||
# Render modes for which the render function returns None.
NO_RETURNS_RENDER = {"human"}

# Render modes for which render returns just a single frame of the current state.
SINGLE_RENDER = {"rgb_array", "depth_array", "state_pixels", "ansi"}


class Renderer:
    """Collect render frames for environments that can compute a single render.

    To use this class:
    - instantiate it with the mode and the function that computes a single frame
    - call render_step each time the frame should be saved in the list
      (usually at the end of the step and reset methods)
    - call get_renders whenever you want to retrieve renders
      (usually in the render method)
    - call reset to clean the render list
      (usually in the reset method of the environment)
    """

    def __init__(
        self,
        mode: Optional[str],
        render: Callable[[str], Any],
        no_returns_render: Optional[Set[str]] = None,
        single_render: Optional[Set[str]] = None,
    ):
        """Instantiates a Renderer object.

        Args:
            mode (Optional[str]): Way to render
            render (Callable[[str], Any]): Function that receives the mode and computes a single frame
            no_returns_render (Optional[Set[str]]): Set of render modes that don't return any value.
                The default value is the set {"human"}.
            single_render (Optional[Set[str]]): Set of render modes that should return a single frame.
                The default value is the set {"rgb_array", "depth_array", "state_pixels", "ansi"}.
        """
        if no_returns_render is None:
            no_returns_render = NO_RETURNS_RENDER
        if single_render is None:
            single_render = SINGLE_RENDER

        self.no_returns_render = no_returns_render
        self.single_render = single_render
        self.mode = mode
        self.render = render
        self.render_list = []

    def render_step(self) -> None:
        """Computes a frame and saves it to the render collection list.

        This method should usually be called inside the environment's step and reset methods.
        Nothing is computed when the mode is None or is a single-frame mode; for modes in
        ``no_returns_render`` the frame is computed but its return value is not stored.
        """
        if self.mode is not None and self.mode not in self.single_render:
            render_return = self.render(self.mode)
            if self.mode not in self.no_returns_render:
                self.render_list.append(render_return)

    def get_renders(self) -> Optional[List]:
        """Pops all the frames from the render collection list.

        This method should usually be called in the environment's render method to retrieve
        the frames collected up to this time step.

        Returns:
            The freshly computed frame for single-frame modes, the list of collected frames
            (which is then emptied) for collection modes, and None when the mode is None or
            is in ``no_returns_render``.
        """
        if self.mode in self.single_render:
            # Single-frame modes are never collected; compute the frame on demand.
            return self.render(self.mode)
        elif self.mode is not None and self.mode not in self.no_returns_render:
            renders = self.render_list
            # Hand the collected list to the caller and start a fresh one.
            self.render_list = []
            return renders
        return None

    def reset(self) -> None:
        """Resets the render collection list.

        This method should usually be called inside the environment's reset method.
        """
        self.render_list = []
|
@@ -12,6 +12,7 @@ from gym.wrappers.normalize import NormalizeObservation, NormalizeReward
|
||||
from gym.wrappers.order_enforcing import OrderEnforcing
|
||||
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
|
||||
from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
|
||||
from gym.wrappers.render_collection import RenderCollection
|
||||
from gym.wrappers.rescale_action import RescaleAction
|
||||
from gym.wrappers.resize_observation import ResizeObservation
|
||||
from gym.wrappers.step_api_compatibility import StepAPICompatibility
|
||||
|
52
gym/wrappers/render_collection.py
Normal file
52
gym/wrappers/render_collection.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""A wrapper that adds render collection mode to an environment."""
|
||||
import gym
|
||||
|
||||
|
||||
class RenderCollection(gym.Wrapper):
    """Save collection of render frames.

    After every ``reset`` and ``step`` the wrapped environment's ``render`` is called once
    and its result is stored; ``render`` on this wrapper returns the frames stored so far.
    """

    def __init__(self, env: gym.Env, pop_frames: bool = True, reset_clean: bool = True):
        """Initialize a :class:`RenderCollection` instance.

        Args:
            env: The environment that is being wrapped
            pop_frames (bool): If true, clear the collection frames after .render() is called.
                Default value is True.
            reset_clean (bool): If true, clear the collection frames when .reset() is called.
                Default value is True.

        Raises:
            AssertionError: If ``env.render_mode`` is None or already ends with ``"_list"``.
        """
        super().__init__(env)
        assert (
            env.render_mode is not None
        ), "RenderCollection requires the wrapped environment to have a render_mode"
        assert not env.render_mode.endswith(
            "_list"
        ), f"the wrapped environment's render_mode ({env.render_mode!r}) must not already end with '_list'"
        self.frame_list = []
        self.reset_clean = reset_clean
        self.pop_frames = pop_frames

    @property
    def render_mode(self):
        """Returns the collection render_mode name."""
        return f"{self.env.render_mode}_list"

    def step(self, *args, **kwargs):
        """Perform a step in the base environment and collect a frame."""
        output = self.env.step(*args, **kwargs)
        self.frame_list.append(self.env.render())
        return output

    def reset(self, *args, **kwargs):
        """Reset the base environment, optionally clear the frame_list, and collect a frame."""
        result = self.env.reset(*args, **kwargs)

        if self.reset_clean:
            self.frame_list = []
        self.frame_list.append(self.env.render())

        return result

    def render(self):
        """Returns the collection of frames and, if pop_frames = True, clears it."""
        if self.pop_frames:
            frames, self.frame_list = self.frame_list, []
            return frames

        # Return a snapshot rather than the live list: otherwise later step()/reset()
        # calls (or the caller) would mutate frames that were already handed out.
        return list(self.frame_list)
|
Reference in New Issue
Block a user