mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-07-31 13:54:31 +00:00
This commit is contained in:
committed by
GitHub
parent
a93da8f271
commit
34dfc9a728
@@ -49,6 +49,30 @@ register(
|
||||
max_episode_steps=500,
|
||||
)
|
||||
|
||||
|
||||
# Phys2d (jax classic control)
# ----------------------------------------

# Jax-backed CartPole. v0 and v1 share one entry point and differ only in the
# episode step limit and the reward threshold.
register(
    id="CartPoleJax-v0",
    entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxEnv",
    reward_threshold=195.0,
    max_episode_steps=200,
)

register(
    id="CartPoleJax-v1",
    entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxEnv",
    reward_threshold=475.0,
    max_episode_steps=500,
)

# Jax-backed Pendulum; registered without a reward threshold.
register(
    id="PendulumJax-v0",
    entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxEnv",
    max_episode_steps=200,
)
|
||||
|
||||
# Box2d
|
||||
# ----------------------------------------
|
||||
|
||||
|
2
gymnasium/envs/phys2d/__init__.py
Normal file
2
gymnasium/envs/phys2d/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from gymnasium.envs.phys2d.cartpole import CartPoleF
|
||||
from gymnasium.envs.phys2d.pendulum import PendulumF
|
BIN
gymnasium/envs/phys2d/assets/clockwise.png
Normal file
BIN
gymnasium/envs/phys2d/assets/clockwise.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 6.8 KiB |
252
gymnasium/envs/phys2d/cartpole.py
Normal file
252
gymnasium/envs/phys2d/cartpole.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""
|
||||
Implementation of a Jax-accelerated cartpole environment.
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax.random import PRNGKey
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.phys2d.conversion import JaxEnv
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.utils import EzPickle
|
||||
|
||||
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821
|
||||
|
||||
|
||||
class CartPoleF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]):
    """Cartpole but in jax and functional.

    The state is the 4-vector ``(x, x_dot, theta, theta_dot)`` and is returned
    unchanged as the observation. Actions are ``0`` (push left) or ``1``
    (push right).

    Example usage:
    ```
    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)

    env = CartPoleF({"x_init": 0.5})
    state = env.initial(key)
    print(state)
    print(env.transition(state, 0))

    env.transform(jax.jit)

    state = env.initial(key)
    print(state)
    print(env.transition(state, 0))

    vkey = jax.random.split(key, 10)
    env.transform(jax.vmap)
    vstate = env.initial(vkey)
    print(vstate)
    print(env.transition(vstate, jnp.array([0 for _ in range(10)])))
    ```
    """

    # Physical constants. These are class attributes, so the derived values
    # (total_mass, polemass_length) are computed once here from the defaults.
    gravity = 9.8
    masscart = 1.0
    masspole = 0.1
    total_mass = masspole + masscart
    length = 0.5  # half the pole length: render_image draws the pole at 2 * length
    polemass_length = masspole * length  # pole mass times (half) length
    force_mag = 10.0
    tau = 0.02  # integration time step used by transition
    theta_threshold_radians = 12 * 2 * np.pi / 360  # 12 degrees, in radians
    x_threshold = 2.4
    x_init = 0.05  # initial state components are drawn from [-x_init, x_init)

    screen_width = 600
    screen_height = 400

    observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32)
    action_space = gym.spaces.Discrete(2)

    def initial(self, rng: PRNGKey):
        """Initial state generation.

        Args:
            rng: jax PRNG key used for sampling.

        Returns:
            A length-4 array sampled uniformly between ``-x_init`` and ``x_init``.
        """
        return jax.random.uniform(
            key=rng, minval=-self.x_init, maxval=self.x_init, shape=(4,)
        )

    def transition(
        self, state: jnp.ndarray, action: Union[int, jnp.ndarray], rng: None = None
    ) -> jnp.ndarray:
        """Cartpole transition.

        Advances ``state`` by one explicit-Euler step of size ``tau``.

        Args:
            state: current ``(x, x_dot, theta, theta_dot)``.
            action: ``0`` applies ``-force_mag``, ``1`` applies ``+force_mag``.
            rng: unused; the dynamics are deterministic.

        Returns:
            The next state as a float32 array of four elements.
        """
        x, x_dot, theta, theta_dot = state
        # Maps action 0 -> -force_mag and action 1 -> +force_mag.
        force = jnp.sign(action - 0.5) * self.force_mag
        costheta = jnp.cos(theta)
        sintheta = jnp.sin(theta)

        # For the interested reader:
        # https://coneural.org/florian/papers/05_cart_pole.pdf
        temp = (
            force + self.polemass_length * theta_dot**2 * sintheta
        ) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass)
        )
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        # Positions are advanced with the velocities from before this step.
        x = x + self.tau * x_dot
        x_dot = x_dot + self.tau * xacc
        theta = theta + self.tau * theta_dot
        theta_dot = theta_dot + self.tau * thetaacc

        state = jnp.array((x, x_dot, theta, theta_dot), dtype=jnp.float32)

        return state

    def observation(self, state: jnp.ndarray) -> jnp.ndarray:
        """Cartpole observation: the state itself."""
        return state

    def _out_of_bounds(self, state: jnp.ndarray) -> jnp.ndarray:
        """Return whether the cart position or pole angle exceeds its threshold."""
        x, _, theta, _ = state

        return (
            (x < -self.x_threshold)
            | (x > self.x_threshold)
            | (theta < -self.theta_threshold_radians)
            | (theta > self.theta_threshold_radians)
        )

    def terminal(self, state: jnp.ndarray) -> jnp.ndarray:
        """Return whether ``state`` is terminal (cart or pole out of bounds)."""
        return self._out_of_bounds(state)

    def reward(
        self, state: StateType, action: ActType, next_state: StateType
    ) -> jnp.ndarray:
        """Return 0.0 if ``state`` is out of bounds, otherwise 1.0.

        NOTE(review): the check uses ``state`` (before the transition), not
        ``next_state`` -- confirm this is the intended reward timing.
        """
        terminated = self._out_of_bounds(state)

        reward = jax.lax.cond(terminated, lambda: 0.0, lambda: 1.0)
        return reward

    def render_image(
        self,
        state: StateType,
        render_state: RenderStateType,
    ) -> Tuple[RenderStateType, np.ndarray]:
        """Draw ``state`` onto the render surface.

        Args:
            state: state to draw; only ``state[0]`` (cart position) and
                ``state[2]`` (pole angle) are read.
            render_state: ``(screen, clock)`` as returned by ``render_init``.

        Returns:
            The (unchanged) render state and the screen pixels transposed to
            axes ``(1, 0, 2)``.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
            from pygame import gfxdraw
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            ) from e
        screen, clock = render_state

        world_width = self.x_threshold * 2
        scale = self.screen_width / world_width
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        x = state

        surf = pygame.Surface((self.screen_width, self.screen_height))
        surf.fill((255, 255, 255))

        # Cart body.
        l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
        axleoffset = cartheight / 4.0
        cartx = x[0] * scale + self.screen_width / 2.0  # MIDDLE OF CART
        carty = 100  # TOP OF CART
        cart_coords = [(l, b), (l, t), (r, t), (r, b)]
        cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
        gfxdraw.aapolygon(surf, cart_coords, (0, 0, 0))
        gfxdraw.filled_polygon(surf, cart_coords, (0, 0, 0))

        # Pole, rotated about the axle by the pole angle.
        l, r, t, b = (
            -polewidth / 2,
            polewidth / 2,
            polelen - polewidth / 2,
            -polewidth / 2,
        )

        pole_coords = []
        for coord in [(l, b), (l, t), (r, t), (r, b)]:
            coord = pygame.math.Vector2(coord).rotate_rad(-x[2])
            coord = (coord[0] + cartx, coord[1] + carty + axleoffset)
            pole_coords.append(coord)
        gfxdraw.aapolygon(surf, pole_coords, (202, 152, 101))
        gfxdraw.filled_polygon(surf, pole_coords, (202, 152, 101))

        # Axle.
        gfxdraw.aacircle(
            surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )
        gfxdraw.filled_circle(
            surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )

        # Track.
        gfxdraw.hline(surf, 0, self.screen_width, carty, (0, 0, 0))

        # The scene is drawn with y pointing up; flip it for the screen.
        surf = pygame.transform.flip(surf, False, True)
        screen.blit(surf, (0, 0))

        return (screen, clock), np.transpose(
            np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2)
        )

    def render_init(
        self, screen_width: int = 600, screen_height: int = 400
    ) -> RenderStateType:
        """Initialise pygame and create the render state.

        Args:
            screen_width: width of the off-screen surface in pixels.
            screen_height: height of the off-screen surface in pixels.

        Returns:
            A ``(screen, clock)`` tuple.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            ) from e

        pygame.init()
        screen = pygame.Surface((screen_width, screen_height))
        clock = pygame.time.Clock()

        return screen, clock

    def render_close(self, render_state: RenderStateType) -> None:
        """Shut down pygame; ``render_state`` itself is not used.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            ) from e
        pygame.display.quit()
        pygame.quit()
|
||||
|
||||
|
||||
class CartPoleJaxEnv(JaxEnv, EzPickle):
    """Gym-API CartPole built on a jitted ``CartPoleF`` functional env."""

    metadata = {"render_modes": ["rgb_array"], "render_fps": 50}

    def __init__(self, render_mode: Optional[str] = None, **kwargs):
        """Build the functional env, jit it and wrap it.

        Args:
            render_mode: render mode passed on to ``JaxEnv``.
            **kwargs: forwarded to the ``CartPoleF`` constructor and recorded
                by ``EzPickle``.
        """
        EzPickle.__init__(self, render_mode=render_mode, **kwargs)

        func_env = CartPoleF(**kwargs)
        func_env.transform(jax.jit)

        super().__init__(
            func_env,
            observation_space=func_env.observation_space,
            action_space=func_env.action_space,
            metadata={"render_modes": ["rgb_array"], "render_fps": 50},
            render_mode=render_mode,
        )
|
121
gymnasium/envs/phys2d/conversion.py
Normal file
121
gymnasium/envs/phys2d/conversion.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jrng
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import Space
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.utils import seeding
|
||||
|
||||
|
||||
class JaxEnv(gym.Env):
    """A conversion layer exposing a jax-based ``FuncEnv`` through the ``gym.Env`` API.

    The wrapper keeps the current state and a jax PRNG key on the instance and
    converts observations to numpy before returning them.
    """

    state: StateType
    rng: jrng.PRNGKey

    def __init__(
        self,
        func_env: FuncEnv,
        observation_space: Space,
        action_space: Space,
        metadata: Optional[Dict[str, Any]] = None,
        render_mode: Optional[str] = None,
        reward_range: Tuple[float, float] = (-float("inf"), float("inf")),
        spec: Optional[EnvSpec] = None,
    ):
        """Initialize the environment from a FuncEnv.

        Args:
            func_env: functional environment that implements the dynamics.
            observation_space: observation space exposed by this env.
            action_space: action space exposed by this env.
            metadata: env metadata; an empty dict is used when omitted.
            render_mode: when ``"rgb_array"``, the render state is created
                immediately via ``func_env.render_init()``.
            reward_range: stored on the instance as-is.
            spec: stored on the instance as-is.
        """
        if metadata is None:
            metadata = {}
        self.func_env = func_env
        self.observation_space = observation_space
        self.action_space = action_space
        self.metadata = metadata
        self.render_mode = render_mode
        self.reward_range = reward_range
        self.spec = spec

        # Cached so step() does not repeat the isinstance check on every call.
        self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)

        if self.render_mode == "rgb_array":
            self.render_state = self.func_env.render_init()
        else:
            self.render_state = None

        # Draw a random seed so an env that is never explicitly seeded still
        # gets its own jax PRNG key.
        np_random, _ = seeding.np_random()
        seed = np_random.integers(0, 2**32 - 1, dtype="uint32")

        self.rng = jrng.PRNGKey(seed)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Sample a fresh initial state.

        Args:
            seed: when given, the jax PRNG key is re-created from it.
            options: accepted for API compatibility; not used here.

        Returns:
            ``(observation, info)`` with the observation converted to numpy.
        """
        super().reset(seed=seed)
        if seed is not None:
            self.rng = jrng.PRNGKey(seed)

        rng, self.rng = jrng.split(self.rng)

        self.state = self.func_env.initial(rng=rng)
        obs = self.func_env.observation(self.state)
        info = self.func_env.state_info(self.state)

        obs = _convert_jax_to_numpy(obs)

        return obs, info

    def step(self, action: ActType):
        """Advance the environment by one transition.

        Box actions are clipped to the action-space bounds; any other action
        must be contained in the action space.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where the
            observation, termination flag and info all describe the state
            reached by this step, and ``truncated`` is always ``False``.
        """
        if self._is_box_action_space:
            assert isinstance(self.action_space, gym.spaces.Box)  # For typing
            action = np.clip(action, self.action_space.low, self.action_space.high)
        else:  # Discrete
            # For now we assume jax envs don't use complex spaces
            err_msg = f"{action!r} ({type(action)}) invalid"
            assert self.action_space.contains(action), err_msg

        rng, self.rng = jrng.split(self.rng)

        next_state = self.func_env.transition(self.state, action, rng)
        # Observe the state we just moved into, so the observation lines up
        # with the termination flag computed from the same state.
        observation = self.func_env.observation(next_state)
        reward = self.func_env.reward(self.state, action, next_state)
        terminated = self.func_env.terminal(next_state)
        info = self.func_env.step_info(self.state, action, next_state)
        self.state = next_state

        observation = _convert_jax_to_numpy(observation)

        return observation, float(reward), bool(terminated), False, info

    def render(self):
        """Return an image of the current state.

        Raises:
            NotImplementedError: for any render mode other than ``"rgb_array"``.
        """
        if self.render_mode == "rgb_array":
            self.render_state, image = self.func_env.render_image(
                self.state, self.render_state
            )
            return image
        else:
            raise NotImplementedError

    def close(self):
        """Release the render state, if one was created."""
        if self.render_state is not None:
            self.func_env.render_close(self.render_state)
            self.render_state = None
|
||||
|
||||
|
||||
def _convert_jax_to_numpy(element: Any):
|
||||
"""
|
||||
Convert a jax observation/action to a numpy array, or a numpy-based container.
|
||||
Currently required because all tests assume that stuff is in numpy arrays, hopefully will be removed soon.
|
||||
"""
|
||||
if isinstance(element, jnp.ndarray):
|
||||
return np.asarray(element)
|
||||
elif isinstance(element, tuple):
|
||||
return tuple(_convert_jax_to_numpy(e) for e in element)
|
||||
elif isinstance(element, list):
|
||||
return [_convert_jax_to_numpy(e) for e in element]
|
||||
elif isinstance(element, dict):
|
||||
return {k: _convert_jax_to_numpy(v) for k, v in element.items()}
|
||||
else:
|
||||
raise TypeError(f"Cannot convert {element} to numpy")
|
201
gymnasium/envs/phys2d/pendulum.py
Normal file
201
gymnasium/envs/phys2d/pendulum.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Implementation of a Jax-accelerated pendulum environment.
|
||||
"""
|
||||
from os import path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax.random import PRNGKey
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.phys2d.conversion import JaxEnv
|
||||
from gymnasium.error import DependencyNotInstalled
|
||||
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.utils import EzPickle
|
||||
|
||||
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821
|
||||
|
||||
|
||||
class PendulumF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]):
    """Functional, jax-based pendulum.

    The state is ``(theta, theta_dot)``; the observation is
    ``(cos(theta), sin(theta), theta_dot)``.
    """

    max_speed = 8
    max_torque = 2.0
    dt = 0.05
    g = 10.0
    m = 1.0
    l = 1.0
    high_x = jnp.pi
    high_y = 1.0

    screen_dim = 500

    observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32)
    action_space = gym.spaces.Box(-max_torque, max_torque, shape=(1,), dtype=np.float32)

    def initial(self, rng: PRNGKey):
        """Sample ``(theta, theta_dot)`` uniformly within ``(+-high_x, +-high_y)``."""
        bounds = jnp.array([self.high_x, self.high_y])
        return jax.random.uniform(
            key=rng, minval=-bounds, maxval=bounds, shape=bounds.shape
        )

    def transition(
        self, state: jnp.ndarray, action: Union[int, jnp.ndarray], rng: None = None
    ) -> jnp.ndarray:
        """Advance the pendulum by one step of length ``dt``.

        The first element of ``action`` is used as the torque after clipping
        to ``[-max_torque, max_torque]``; ``rng`` is not used.
        """
        theta, theta_dot = state
        torque = jnp.clip(action, -self.max_torque, self.max_torque)[0]

        gravity_term = 3 * self.g / (2 * self.l) * jnp.sin(theta)
        torque_term = 3.0 / (self.m * self.l**2) * torque

        next_theta_dot = theta_dot + (gravity_term + torque_term) * self.dt
        next_theta_dot = jnp.clip(next_theta_dot, -self.max_speed, self.max_speed)
        # The angle is integrated with the already-updated velocity.
        next_theta = theta + next_theta_dot * self.dt

        return jnp.array([next_theta, next_theta_dot])

    def observation(self, state: jnp.ndarray) -> jnp.ndarray:
        """Return ``(cos(theta), sin(theta), theta_dot)`` for ``state``."""
        angle, angular_velocity = state
        return jnp.array([jnp.cos(angle), jnp.sin(angle), angular_velocity])

    def reward(self, state: StateType, action: ActType, next_state: StateType) -> float:
        """Return the negative quadratic cost of ``state`` and the clipped torque.

        ``next_state`` is not used.
        """
        theta, theta_dot = state
        torque = jnp.clip(action, -self.max_torque, self.max_torque)[0]

        # Wrap the angle into [-pi, pi) before squaring it.
        wrapped_theta = ((theta + jnp.pi) % (2 * jnp.pi)) - jnp.pi
        cost = wrapped_theta**2 + 0.1 * theta_dot**2 + 0.001 * (torque**2)

        return -cost

    def terminal(self, state: StateType) -> bool:
        """The pendulum task never terminates."""
        return False

    def render_image(
        self,
        state: StateType,
        render_state: RenderStateType,
    ) -> Tuple[RenderStateType, np.ndarray]:
        """Draw ``state`` onto the render surface.

        Args:
            state: state to draw; only ``state[0]`` (the angle) is read.
            render_state: ``(screen, clock, last_u)`` as produced by
                ``render_init``; the torque arrow is drawn only when
                ``last_u`` is not ``None``.

        Returns:
            The render state (unchanged) and the screen pixels transposed to
            axes ``(1, 0, 2)``.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
            from pygame import gfxdraw
        except ImportError:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            )
        screen, clock, last_u = render_state

        # NOTE(review): the canvas is screen_dim x screen_dim, while
        # render_init defaults to a 600 x 400 screen -- confirm the mismatch
        # is intended.
        canvas = pygame.Surface((self.screen_dim, self.screen_dim))
        canvas.fill((255, 255, 255))

        bound = 2.2
        scale = self.screen_dim / (bound * 2)
        centre = self.screen_dim // 2
        rod_colour = (204, 77, 77)

        rod_length = 1 * scale
        rod_width = 0.2 * scale
        joint_radius = int(rod_width / 2)
        rod_angle = state[0] + np.pi / 2

        # Rod body.
        left, right, top, bottom = 0, rod_length, rod_width / 2, -rod_width / 2
        rod_corners = []
        for corner in [(left, bottom), (left, top), (right, top), (right, bottom)]:
            rotated = pygame.math.Vector2(corner).rotate_rad(rod_angle)
            rod_corners.append((rotated[0] + centre, rotated[1] + centre))
        gfxdraw.aapolygon(canvas, rod_corners, rod_colour)
        gfxdraw.filled_polygon(canvas, rod_corners, rod_colour)

        # Rounded cap at the pivot end.
        gfxdraw.aacircle(canvas, centre, centre, joint_radius, rod_colour)
        gfxdraw.filled_circle(canvas, centre, centre, joint_radius, rod_colour)

        # Rounded cap at the free end.
        tip = pygame.math.Vector2((rod_length, 0)).rotate_rad(rod_angle)
        tip_x, tip_y = int(tip[0] + centre), int(tip[1] + centre)
        gfxdraw.aacircle(canvas, tip_x, tip_y, joint_radius, rod_colour)
        gfxdraw.filled_circle(canvas, tip_x, tip_y, joint_radius, rod_colour)

        # Torque arrow, sized by |last_u| and mirrored by its sign.
        fname = path.join(path.dirname(__file__), "assets/clockwise.png")
        img = pygame.image.load(fname)
        if last_u is not None:
            arrow_side = scale * np.abs(last_u) / 2
            scale_img = pygame.transform.smoothscale(img, (arrow_side, arrow_side))
            is_flip = bool(last_u > 0)
            scale_img = pygame.transform.flip(scale_img, is_flip, True)
            arrow_rect = scale_img.get_rect()
            canvas.blit(
                scale_img,
                (centre - arrow_rect.centerx, centre - arrow_rect.centery),
            )

        # drawing axle
        axle_radius = int(0.05 * scale)
        gfxdraw.aacircle(canvas, centre, centre, axle_radius, (0, 0, 0))
        gfxdraw.filled_circle(canvas, centre, centre, axle_radius, (0, 0, 0))

        canvas = pygame.transform.flip(canvas, False, True)
        screen.blit(canvas, (0, 0))

        pixels = np.array(pygame.surfarray.pixels3d(screen))
        return (screen, clock, last_u), np.transpose(pixels, axes=(1, 0, 2))

    def render_init(
        self, screen_width: int = 600, screen_height: int = 400
    ) -> RenderStateType:
        """Initialise pygame and return ``(screen, clock, None)``.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            )

        pygame.init()
        surface = pygame.Surface((screen_width, screen_height))
        timer = pygame.time.Clock()

        return surface, timer, None

    def render_close(self, render_state: RenderStateType) -> None:
        """Shut down pygame; ``render_state`` itself is not used.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            )
        pygame.display.quit()
        pygame.quit()
|
||||
|
||||
|
||||
class PendulumJaxEnv(JaxEnv, EzPickle):
    """Gym-API Pendulum built on a jitted ``PendulumF`` functional env."""

    metadata = {"render_modes": ["rgb_array"], "render_fps": 30}

    def __init__(self, render_mode: Optional[str] = None, **kwargs):
        """Build the functional env, jit it and wrap it.

        Args:
            render_mode: render mode passed on to ``JaxEnv``.
            **kwargs: forwarded to the ``PendulumF`` constructor and recorded
                by ``EzPickle``.
        """
        EzPickle.__init__(self, render_mode=render_mode, **kwargs)

        func_env = PendulumF(**kwargs)
        func_env.transform(jax.jit)

        super().__init__(
            func_env,
            observation_space=func_env.observation_space,
            action_space=func_env.action_space,
            metadata={"render_modes": ["rgb_array"], "render_fps": 30},
            render_mode=render_mode,
        )
|
96
gymnasium/functional.py
Normal file
96
gymnasium/functional.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Base class and definitions for an alternative, functional backend for gym envs, particularly suitable for hardware accelerated and otherwise transformed environments."""
|
||||
|
||||
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Type parameters of FuncEnv, in the order they appear in its Generic base.
StateType = TypeVar("StateType")
ActType = TypeVar("ActType")
ObsType = TypeVar("ObsType")
RewardType = TypeVar("RewardType")
TerminalType = TypeVar("TerminalType")
RenderStateType = TypeVar("RenderStateType")


class FuncEnv(
    Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType]
):
    """Base class (template) for functional envs.

    The API is designed for stateless use: the environment state is passed to
    and returned from every call instead of being stored on the object.
    Stateful use is not prevented, only discouraged.

    A functional env is made of these instance methods:
    - initial: returns the initial state of the POMDP
    - observation: returns the observation in a given state
    - transition: returns the next state after taking an action in a given state
    - reward: returns the reward for a given (state, action, next_state) tuple
    - terminal: returns whether a given state is terminal
    - state_info: optional, returns a dict of info about a given state
    - step_info: optional, returns a dict of info about a given (state, action, next_state) tuple

    Using a class lets environment constants be declared as attributes and
    referred to by name inside the methods.

    For now this is mainly meant for internal use and the API may still
    change; the intention is to flesh it out and expose it officially later.
    """

    def __init__(self, options: Optional[Dict[str, Any]] = None):
        """Initialize the environment constants.

        Every entry of ``options`` becomes an instance attribute.
        """
        if options:
            self.__dict__.update(options)

    def initial(self, rng: Any) -> StateType:
        """Return an initial state; must be overridden."""
        raise NotImplementedError

    def observation(self, state: StateType) -> ObsType:
        """Return the observation for ``state``; must be overridden."""
        raise NotImplementedError

    def transition(self, state: StateType, action: ActType, rng: Any) -> StateType:
        """Return the state that follows ``state`` under ``action``; must be overridden."""
        raise NotImplementedError

    def reward(
        self, state: StateType, action: ActType, next_state: StateType
    ) -> RewardType:
        """Return the reward of a transition; must be overridden."""
        raise NotImplementedError

    def terminal(self, state: StateType) -> TerminalType:
        """Return whether ``state`` is terminal; must be overridden."""
        raise NotImplementedError

    def state_info(self, state: StateType) -> dict:
        """Info dict about a single state; empty by default."""
        return {}

    def step_info(
        self, state: StateType, action: ActType, next_state: StateType
    ) -> dict:
        """Info dict about a full transition; empty by default."""
        return {}

    def transform(self, func: Callable[[Callable], Callable]):
        """Replace each functional method on this instance with ``func(method)``.

        The render methods are left untouched.
        """
        wrapped_names = (
            "initial",
            "transition",
            "observation",
            "reward",
            "terminal",
            "state_info",
            "step_info",
        )
        for method_name in wrapped_names:
            setattr(self, method_name, func(getattr(self, method_name)))

    def render_image(
        self, state: StateType, render_state: RenderStateType
    ) -> Tuple[RenderStateType, np.ndarray]:
        """Return the new render state and an image of ``state``; must be overridden."""
        raise NotImplementedError

    def render_init(self, **kwargs) -> RenderStateType:
        """Create the render state; must be overridden."""
        raise NotImplementedError

    def render_close(self, render_state: RenderStateType):
        """Dispose of the render state; must be overridden."""
        raise NotImplementedError
|
@@ -12,4 +12,4 @@ pygame==2.1.0
|
||||
ale-py~=0.8.0
|
||||
mujoco==2.2
|
||||
mujoco_py<2.2,>=2.1
|
||||
imageio>=2.14.1
|
||||
imageio>=2.14.1
|
||||
|
1
setup.py
1
setup.py
@@ -40,6 +40,7 @@ extras = {
|
||||
"mujoco_py": ["mujoco_py<2.2,>=2.1"],
|
||||
"mujoco": ["mujoco==2.2", "imageio>=2.14.1"],
|
||||
"toy_text": ["pygame==2.1.0"],
|
||||
"jax": ["jax==0.3.20", "jaxlib==0.3.20"],
|
||||
"other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0", "moviepy>=1.0.0"],
|
||||
}
|
||||
|
||||
|
0
tests/envs/functional/__init__.py
Normal file
0
tests/envs/functional/__init__.py
Normal file
56
tests/envs/functional/test_core.py
Normal file
56
tests/envs/functional/test_core.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gymnasium.functional import FuncEnv
|
||||
|
||||
|
||||
class TestEnv(FuncEnv):
    """Minimal numpy FuncEnv whose second state component accumulates the actions."""

    def __init__(self, options: Optional[Dict[str, Any]] = None):
        super().__init__(options)

    def initial(self, rng: Any) -> np.ndarray:
        """Start from the float32 zero vector; ``rng`` is ignored."""
        return np.zeros(2, dtype=np.float32)

    def observation(self, state: np.ndarray) -> np.ndarray:
        """The state is fully observed."""
        return state

    def transition(self, state: np.ndarray, action: int, rng: None) -> np.ndarray:
        """Add ``action`` to the second state component."""
        delta = np.array([0, action], dtype=np.float32)
        return state + delta

    def reward(self, state: np.ndarray, action: int, next_state: np.ndarray) -> float:
        """Reward 1.0 once the second component of ``next_state`` is positive."""
        if next_state[1] > 0:
            return 1.0
        return 0.0

    def terminal(self, state: np.ndarray) -> bool:
        """A state is terminal when its second component is positive."""
        return state[1] > 0
|
||||
|
||||
|
||||
def test_api():
    """Walk TestEnv through a fixed action sequence and check every API call."""
    env = TestEnv()
    state = env.initial(None)
    obs = env.observation(state)

    for array in (state, obs):
        assert array.shape == (2,)
        assert array.dtype == np.float32
    assert np.allclose(obs, state)

    actions = [-1, -2, -5, 3, 5, 2]
    for step_index, action in enumerate(actions):
        next_state = env.transition(state, action, None)
        assert next_state.shape == (2,)
        assert next_state.dtype == np.float32
        assert np.allclose(next_state, state + np.array([0, action]))

        observation = env.observation(next_state)
        assert observation.shape == (2,)
        assert observation.dtype == np.float32
        assert np.allclose(observation, next_state)

        expected_reward = 1.0 if next_state[1] > 0 else 0.0
        assert env.reward(state, action, next_state) == expected_reward

        # terminal state is in the final action
        is_last_step = step_index == 5
        assert env.terminal(next_state) == is_last_step

        state = next_state
|
105
tests/envs/functional/test_jax.py
Normal file
105
tests/envs/functional/test_jax.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jrng
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from gymnasium.envs.phys2d.cartpole import CartPoleF # noqa: E402
|
||||
from gymnasium.envs.phys2d.pendulum import PendulumF # noqa: E402
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF])
def test_normal(env_class):
    """Roll out ten steps without any transformation and check shapes and dtypes."""
    env = env_class()
    state = env.initial(jrng.PRNGKey(0))
    env.action_space.seed(0)

    for _ in range(10):
        obs = env.observation(state)
        action = env.action_space.sample()
        next_state = env.transition(state, action, None)
        reward = env.reward(state, action, next_state)
        terminal = env.terminal(next_state)

        assert next_state.shape == state.shape

        # Reward and terminal only need to be castable to Python scalars.
        try:
            float(reward)
        except ValueError:
            pytest.fail("Reward is not castable to float")
        try:
            bool(terminal)
        except ValueError:
            pytest.fail("Terminal is not castable to bool")

        assert next_state.dtype == jnp.float32
        assert isinstance(obs, jnp.ndarray)
        assert obs.dtype == jnp.float32

        state = next_state
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF])
def test_jit(env_class):
    """Roll out ten steps with every functional method wrapped in ``jax.jit``."""
    env = env_class()
    env.transform(jax.jit)

    state = env.initial(jrng.PRNGKey(0))
    env.action_space.seed(0)

    for _ in range(10):
        obs = env.observation(state)
        action = env.action_space.sample()
        next_state = env.transition(state, action, None)
        reward = env.reward(state, action, next_state)
        terminal = env.terminal(next_state)

        assert next_state.shape == state.shape

        # Reward and terminal only need to be castable to Python scalars.
        try:
            float(reward)
        except ValueError:
            pytest.fail("Reward is not castable to float")
        try:
            bool(terminal)
        except ValueError:
            pytest.fail("Terminal is not castable to bool")

        assert next_state.dtype == jnp.float32
        assert isinstance(obs, jnp.ndarray)
        assert obs.dtype == jnp.float32

        state = next_state
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF])
def test_vmap(env_class):
    """Roll out ten steps of a batch of envs under ``jax.vmap`` then ``jax.jit``.

    Every functional method is vectorised over a leading batch axis of size
    ``num_envs``; rewards and terminal flags must come back with one entry per
    environment.
    """
    env = env_class()
    num_envs = 10
    # One PRNG key per environment in the batch.
    rng = jrng.split(jrng.PRNGKey(0), num_envs)

    env.transform(jax.vmap)
    env.transform(jax.jit)
    state = env.initial(rng)
    env.action_space.seed(0)

    for t in range(10):
        obs = env.observation(state)
        action = jnp.array([env.action_space.sample() for _ in range(num_envs)])
        next_state = env.transition(state, action, None)
        terminal = env.terminal(next_state)
        reward = env.reward(state, action, next_state)

        assert next_state.shape == state.shape
        assert next_state.dtype == jnp.float32
        assert reward.shape == (num_envs,)
        assert reward.dtype == jnp.float32
        assert terminal.shape == (num_envs,)
        # `np.bool` was a deprecated alias of the builtin and is absent from
        # NumPy 1.24+; `np.bool_` is the actual boolean scalar type.
        assert terminal.dtype == np.bool_
        assert isinstance(obs, jnp.ndarray)
        assert obs.dtype == jnp.float32

        state = next_state
|
@@ -171,7 +171,6 @@ def test_render_modes(spec):
|
||||
env = spec.make()
|
||||
|
||||
assert "rgb_array" in env.metadata["render_modes"]
|
||||
assert "human" in env.metadata["render_modes"]
|
||||
|
||||
for mode in env.metadata["render_modes"]:
|
||||
if mode != "human":
|
||||
|
@@ -18,173 +18,3 @@ CartPole-v1
|
||||
|
||||
"""
|
||||
assert out == correct_out
|
||||
|
||||
|
||||
def test_pprint_registry():
# Snapshot test: calls gym.pprint_registry(disable_print=True) and requires the
# returned value to equal the `correct_out` text below exactly.
# NOTE(review): this text comes from a rendered diff view. The bare `|` / `||||`
# lines are table residue rather than Python, and the column padding inside
# `correct_out` appears to have been collapsed to single spaces - confirm the
# exact literal against the repository before reusing this test.
|
||||
"""Testing the default registry, with no changes."""
|
||||
out = gym.pprint_registry(disable_print=True)
|
||||
|
||||
# Expected listing: one "===== <namespace> =====" header per group, followed by
# the environment ids in that group and a blank line.
correct_out = """===== classic_control =====
|
||||
Acrobot-v1
|
||||
CartPole-v0
|
||||
CartPole-v1
|
||||
MountainCar-v0
|
||||
MountainCarContinuous-v0
|
||||
Pendulum-v1
|
||||
|
||||
===== box2d =====
|
||||
BipedalWalker-v3
|
||||
BipedalWalkerHardcore-v3
|
||||
CarRacing-v2
|
||||
LunarLander-v2
|
||||
LunarLanderContinuous-v2
|
||||
|
||||
===== toy_text =====
|
||||
Blackjack-v1
|
||||
CliffWalking-v0
|
||||
FrozenLake-v1
|
||||
FrozenLake8x8-v1
|
||||
Taxi-v3
|
||||
|
||||
===== mujoco =====
|
||||
Ant-v2 Ant-v3 Ant-v4
|
||||
HalfCheetah-v2 HalfCheetah-v3 HalfCheetah-v4
|
||||
Hopper-v2 Hopper-v3 Hopper-v4
|
||||
Humanoid-v2 Humanoid-v3 Humanoid-v4
|
||||
HumanoidStandup-v2 HumanoidStandup-v4 InvertedDoublePendulum-v2
|
||||
InvertedDoublePendulum-v4 InvertedPendulum-v2 InvertedPendulum-v4
|
||||
Pusher-v2 Pusher-v4 Reacher-v2
|
||||
Reacher-v4 Swimmer-v2 Swimmer-v3
|
||||
Swimmer-v4 Walker2d-v2 Walker2d-v3
|
||||
Walker2d-v4
|
||||
|
||||
===== external =====
|
||||
GymV26Environment-v0
|
||||
|
||||
===== utils_envs =====
|
||||
RegisterDuringMakeEnv-v0
|
||||
test.ArgumentEnv-v0
|
||||
test.OrderlessArgumentEnv-v0
|
||||
|
||||
===== test =====
|
||||
test/NoHuman-v0
|
||||
test/NoHumanNoRGB-v0
|
||||
test/NoHumanOldAPI-v0
|
||||
|
||||
"""
|
||||
# Exact string equality: any change in ids, ordering or spacing fails the test.
assert out == correct_out
|
||||
|
||||
|
||||
def test_pprint_registry_exclude_namespaces():
# Snapshot test for gym.pprint_registry called with max_rows=20 and
# exclude_namespaces=["classic_control"]: the expected text below starts at the
# box2d section and contains no classic_control section.
# NOTE(review): the docstring below repeats the wording of test_pprint_registry
# ("with no changes") and does not mention the exclusion exercised here.
# NOTE(review): this text comes from a rendered diff view. The bare `|` / `||||`
# lines are table residue rather than Python, and the column padding inside
# `correct_out` appears to have been collapsed to single spaces - confirm the
# exact literal against the repository before reusing this test.
|
||||
"""Testing the default registry, with no changes."""
|
||||
out = gym.pprint_registry(
|
||||
max_rows=20, exclude_namespaces=["classic_control"], disable_print=True
|
||||
)
|
||||
|
||||
# Expected listing; the mujoco section here shows two ids per line.
correct_out = """===== box2d =====
|
||||
BipedalWalker-v3
|
||||
BipedalWalkerHardcore-v3
|
||||
CarRacing-v2
|
||||
LunarLander-v2
|
||||
LunarLanderContinuous-v2
|
||||
|
||||
===== toy_text =====
|
||||
Blackjack-v1
|
||||
CliffWalking-v0
|
||||
FrozenLake-v1
|
||||
FrozenLake8x8-v1
|
||||
Taxi-v3
|
||||
|
||||
===== mujoco =====
|
||||
Ant-v2 Ant-v3
|
||||
Ant-v4 HalfCheetah-v2
|
||||
HalfCheetah-v3 HalfCheetah-v4
|
||||
Hopper-v2 Hopper-v3
|
||||
Hopper-v4 Humanoid-v2
|
||||
Humanoid-v3 Humanoid-v4
|
||||
HumanoidStandup-v2 HumanoidStandup-v4
|
||||
InvertedDoublePendulum-v2 InvertedDoublePendulum-v4
|
||||
InvertedPendulum-v2 InvertedPendulum-v4
|
||||
Pusher-v2 Pusher-v4
|
||||
Reacher-v2 Reacher-v4
|
||||
Swimmer-v2 Swimmer-v3
|
||||
Swimmer-v4 Walker2d-v2
|
||||
Walker2d-v3 Walker2d-v4
|
||||
|
||||
===== external =====
|
||||
GymV26Environment-v0
|
||||
|
||||
===== utils_envs =====
|
||||
RegisterDuringMakeEnv-v0
|
||||
test.ArgumentEnv-v0
|
||||
test.OrderlessArgumentEnv-v0
|
||||
|
||||
===== test =====
|
||||
test/NoHuman-v0
|
||||
test/NoHumanNoRGB-v0
|
||||
test/NoHumanOldAPI-v0
|
||||
|
||||
"""
|
||||
# Exact string equality: any change in ids, ordering or spacing fails the test.
assert out == correct_out
|
||||
|
||||
|
||||
def test_pprint_registry_no_entry_point():
# Snapshot test: registers an extra id "NoNamespaceEnv", renders the registry
# with gym.pprint_registry(disable_print=True), and expects the default listing
# plus a trailing "===== NoNamespaceEnv =====" section containing that id.
# NOTE(review): this text comes from a rendered diff view. The bare `|` / `||||`
# lines are table residue rather than Python, and the column padding inside
# `correct_out` appears to have been collapsed to single spaces - confirm the
# exact literal against the repository before reusing this test.
|
||||
"""Test registry if there is environment with no entry point."""
|
||||
|
||||
# "no-entry-point" is passed as the second positional argument (presumably the
# entry_point parameter - verify against gym.register's signature).
gym.register("NoNamespaceEnv", "no-entry-point")
|
||||
out = gym.pprint_registry(disable_print=True)
|
||||
|
||||
correct_out = """===== classic_control =====
|
||||
Acrobot-v1
|
||||
CartPole-v0
|
||||
CartPole-v1
|
||||
MountainCar-v0
|
||||
MountainCarContinuous-v0
|
||||
Pendulum-v1
|
||||
|
||||
===== box2d =====
|
||||
BipedalWalker-v3
|
||||
BipedalWalkerHardcore-v3
|
||||
CarRacing-v2
|
||||
LunarLander-v2
|
||||
LunarLanderContinuous-v2
|
||||
|
||||
===== toy_text =====
|
||||
Blackjack-v1
|
||||
CliffWalking-v0
|
||||
FrozenLake-v1
|
||||
FrozenLake8x8-v1
|
||||
Taxi-v3
|
||||
|
||||
===== mujoco =====
|
||||
Ant-v2 Ant-v3 Ant-v4
|
||||
HalfCheetah-v2 HalfCheetah-v3 HalfCheetah-v4
|
||||
Hopper-v2 Hopper-v3 Hopper-v4
|
||||
Humanoid-v2 Humanoid-v3 Humanoid-v4
|
||||
HumanoidStandup-v2 HumanoidStandup-v4 InvertedDoublePendulum-v2
|
||||
InvertedDoublePendulum-v4 InvertedPendulum-v2 InvertedPendulum-v4
|
||||
Pusher-v2 Pusher-v4 Reacher-v2
|
||||
Reacher-v4 Swimmer-v2 Swimmer-v3
|
||||
Swimmer-v4 Walker2d-v2 Walker2d-v3
|
||||
Walker2d-v4
|
||||
|
||||
===== external =====
|
||||
GymV26Environment-v0
|
||||
|
||||
===== utils_envs =====
|
||||
RegisterDuringMakeEnv-v0
|
||||
test.ArgumentEnv-v0
|
||||
test.OrderlessArgumentEnv-v0
|
||||
|
||||
===== test =====
|
||||
test/NoHuman-v0
|
||||
test/NoHumanNoRGB-v0
|
||||
test/NoHumanOldAPI-v0
|
||||
|
||||
===== NoNamespaceEnv =====
|
||||
NoNamespaceEnv
|
||||
|
||||
"""
|
||||
# Exact string equality: any change in ids, ordering or spacing fails the test.
assert out == correct_out
|
||||
|
||||
# Remove the id registered above so it is no longer in gym.envs.registry.
# NOTE(review): this line is not reached when the assertion above fails, which
# leaves "NoNamespaceEnv" registered - consider wrapping in try/finally.
del gym.envs.registry["NoNamespaceEnv"]
|
||||
|
@@ -2,7 +2,7 @@ import pytest
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
|
||||
from gymnasium.wrappers import OrderEnforcing, TimeLimit, TransformObservation
|
||||
from gymnasium.wrappers import TimeLimit, TransformObservation
|
||||
from gymnasium.wrappers.env_checker import PassiveEnvChecker
|
||||
from tests.wrappers.utils import has_wrapper
|
||||
|
||||
@@ -39,8 +39,6 @@ def test_vector_make_wrappers():
|
||||
sub_env = env.envs[0]
|
||||
assert isinstance(sub_env, gym.Env)
|
||||
assert sub_env.spec is not None
|
||||
if sub_env.spec.order_enforce:
|
||||
assert has_wrapper(sub_env, OrderEnforcing)
|
||||
if sub_env.spec.max_episode_steps is not None:
|
||||
assert has_wrapper(sub_env, TimeLimit)
|
||||
|
||||
|
Reference in New Issue
Block a user