mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-01 06:07:08 +00:00
This commit is contained in:
committed by
GitHub
parent
a93da8f271
commit
34dfc9a728
@@ -49,6 +49,30 @@ register(
|
|||||||
max_episode_steps=500,
|
max_episode_steps=500,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Phys2d (jax classic control)
|
||||||
|
# ----------------------------------------
|
||||||
|
|
||||||
|
# Jax-backed CartPole, short horizon (200 steps per episode).
register(
    id="CartPoleJax-v0",
    entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxEnv",
    reward_threshold=195.0,
    max_episode_steps=200,
)

# Jax-backed CartPole, long horizon (500 steps per episode).
register(
    id="CartPoleJax-v1",
    entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxEnv",
    reward_threshold=475.0,
    max_episode_steps=500,
)

# Jax-backed Pendulum; registered without a reward threshold.
register(
    id="PendulumJax-v0",
    entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxEnv",
    max_episode_steps=200,
)
|
||||||
|
|
||||||
# Box2d
|
# Box2d
|
||||||
# ----------------------------------------
|
# ----------------------------------------
|
||||||
|
|
||||||
|
2
gymnasium/envs/phys2d/__init__.py
Normal file
2
gymnasium/envs/phys2d/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
from gymnasium.envs.phys2d.cartpole import CartPoleF
|
||||||
|
from gymnasium.envs.phys2d.pendulum import PendulumF
|
BIN
gymnasium/envs/phys2d/assets/clockwise.png
Normal file
BIN
gymnasium/envs/phys2d/assets/clockwise.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 6.8 KiB |
252
gymnasium/envs/phys2d/cartpole.py
Normal file
252
gymnasium/envs/phys2d/cartpole.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
"""
|
||||||
|
Implementation of a Jax-accelerated cartpole environment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.envs.phys2d.conversion import JaxEnv
|
||||||
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||||
|
from gymnasium.utils import EzPickle
|
||||||
|
|
||||||
|
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821
|
||||||
|
|
||||||
|
|
||||||
|
class CartPoleF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]):
    """Cartpole but in jax and functional.

    The state is the 4-vector ``(x, x_dot, theta, theta_dot)`` and is returned
    unchanged as the observation. Physical and rendering constants are class
    attributes so that the methods can refer to them by name.

    Example usage:
    ```
    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)

    env = CartPoleF({"x_init": 0.5})
    state = env.initial(key)
    print(state)
    print(env.transition(state, 0))

    env.transform(jax.jit)

    state = env.initial(key)
    print(state)
    print(env.transition(state, 0))

    vkey = jax.random.split(key, 10)
    env.transform(jax.vmap)
    vstate = env.initial(vkey)
    print(vstate)
    print(env.transition(vstate, jnp.array([0 for _ in range(10)])))
    ```
    """

    gravity = 9.8
    masscart = 1.0
    masspole = 0.1
    total_mass = masspole + masscart
    length = 0.5  # half the pole length: the renderer draws 2 * length
    # Pole mass times (half-)length, as used by the dynamics equations.
    # NOTE: derived once at class creation; overriding `masspole` or `length`
    # through the options dict does not recompute it.
    polemass_length = masspole * length
    force_mag = 10.0
    tau = 0.02  # integration time step used by `transition`
    theta_threshold_radians = 12 * 2 * np.pi / 360  # 12 degrees, in radians
    x_threshold = 2.4
    x_init = 0.05  # every initial state component is drawn from [-x_init, x_init)

    screen_width = 600
    screen_height = 400

    observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32)
    action_space = gym.spaces.Discrete(2)

    def initial(self, rng: PRNGKey):
        """Initial state generation.

        Args:
            rng: jax PRNG key used to sample the state.

        Returns:
            A length-4 array sampled uniformly from ``[-x_init, x_init)``.
        """
        return jax.random.uniform(
            key=rng, minval=-self.x_init, maxval=self.x_init, shape=(4,)
        )

    def transition(
        self, state: jnp.ndarray, action: Union[int, jnp.ndarray], rng: None = None
    ) -> StateType:
        """Cartpole transition.

        Applies a force of magnitude ``force_mag`` whose sign is given by
        ``action - 0.5`` and advances the state by one explicit Euler step of
        size ``tau``.

        Args:
            state: current ``(x, x_dot, theta, theta_dot)``.
            action: discrete action; 0 pushes in the negative direction, 1 in the positive.
            rng: unused, the dynamics are deterministic.

        Returns:
            The next state as a float32 array of shape ``(4,)``.
        """
        x, x_dot, theta, theta_dot = state
        force = jnp.sign(action - 0.5) * self.force_mag
        costheta = jnp.cos(theta)
        sintheta = jnp.sin(theta)

        # For the interested reader:
        # https://coneural.org/florian/papers/05_cart_pole.pdf
        temp = (
            force + self.polemass_length * theta_dot**2 * sintheta
        ) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass)
        )
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        # Euler integration: positions use the old velocities.
        x = x + self.tau * x_dot
        x_dot = x_dot + self.tau * xacc
        theta = theta + self.tau * theta_dot
        theta_dot = theta_dot + self.tau * thetaacc

        state = jnp.array((x, x_dot, theta, theta_dot), dtype=jnp.float32)

        return state

    def observation(self, state: jnp.ndarray) -> jnp.ndarray:
        """Cartpole observation: the state itself."""
        return state

    def terminal(self, state: jnp.ndarray) -> jnp.ndarray:
        """Return whether the cart position or pole angle is outside its threshold."""
        x, _, theta, _ = state

        terminated = (
            (x < -self.x_threshold)
            | (x > self.x_threshold)
            | (theta < -self.theta_threshold_radians)
            | (theta > self.theta_threshold_radians)
        )

        return terminated

    def reward(
        self, state: StateType, action: ActType, next_state: StateType
    ) -> jnp.ndarray:
        """Return 1.0 unless `state` (the state acted from) is already out of bounds.

        The check is made on `state`, not `next_state`, so the step that leaves
        the valid region is still rewarded.
        """
        x, _, theta, _ = state

        terminated = (
            (x < -self.x_threshold)
            | (x > self.x_threshold)
            | (theta < -self.theta_threshold_radians)
            | (theta > self.theta_threshold_radians)
        )

        reward = jax.lax.cond(terminated, lambda: 0.0, lambda: 1.0)
        return reward

    def render_image(
        self,
        state: StateType,
        render_state: RenderStateType,
    ) -> Tuple[RenderStateType, np.ndarray]:
        """Draw `state` onto the render surface and return it as an image array.

        Args:
            state: the state to draw; only ``state[0]`` (x) and ``state[2]`` (theta) are read.
            render_state: ``(screen, clock)`` as produced by `render_init`.

        Returns:
            The (unchanged) render state and the screen pixels with the first
            two axes swapped relative to ``pygame.surfarray.pixels3d``.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
            from pygame import gfxdraw
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            ) from e
        screen, clock = render_state

        world_width = self.x_threshold * 2
        scale = self.screen_width / world_width
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        x = state

        surf = pygame.Surface((self.screen_width, self.screen_height))
        surf.fill((255, 255, 255))

        # Cart rectangle, centred on (cartx, carty).
        l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
        axleoffset = cartheight / 4.0
        cartx = x[0] * scale + self.screen_width / 2.0  # MIDDLE OF CART
        carty = 100  # TOP OF CART
        cart_coords = [(l, b), (l, t), (r, t), (r, b)]
        cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
        gfxdraw.aapolygon(surf, cart_coords, (0, 0, 0))
        gfxdraw.filled_polygon(surf, cart_coords, (0, 0, 0))

        # Pole rectangle, rotated by -theta about the axle.
        l, r, t, b = (
            -polewidth / 2,
            polewidth / 2,
            polelen - polewidth / 2,
            -polewidth / 2,
        )

        pole_coords = []
        for coord in [(l, b), (l, t), (r, t), (r, b)]:
            coord = pygame.math.Vector2(coord).rotate_rad(-x[2])
            coord = (coord[0] + cartx, coord[1] + carty + axleoffset)
            pole_coords.append(coord)
        gfxdraw.aapolygon(surf, pole_coords, (202, 152, 101))
        gfxdraw.filled_polygon(surf, pole_coords, (202, 152, 101))

        # Axle.
        gfxdraw.aacircle(
            surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )
        gfxdraw.filled_circle(
            surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )

        # Track line.
        gfxdraw.hline(surf, 0, self.screen_width, carty, (0, 0, 0))

        # The scene is drawn with y pointing up; flip vertically before blitting.
        surf = pygame.transform.flip(surf, False, True)
        screen.blit(surf, (0, 0))

        return (screen, clock), np.transpose(
            np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2)
        )

    def render_init(
        self, screen_width: int = 600, screen_height: int = 400
    ) -> RenderStateType:
        """Initialise pygame and create the off-screen surface and clock.

        Args:
            screen_width: width of the surface in pixels.
            screen_height: height of the surface in pixels.

        Returns:
            The ``(screen, clock)`` render state.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            ) from e

        pygame.init()
        screen = pygame.Surface((screen_width, screen_height))
        clock = pygame.time.Clock()

        return screen, clock

    def render_close(self, render_state: RenderStateType) -> None:
        """Shut down the pygame display module and pygame itself.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            ) from e
        pygame.display.quit()
        pygame.quit()
|
||||||
|
|
||||||
|
|
||||||
|
class CartPoleJaxEnv(JaxEnv, EzPickle):
    """Stateful ``gym.Env`` wrapper around a jit-compiled `CartPoleF`."""

    metadata = {"render_modes": ["rgb_array"], "render_fps": 50}

    def __init__(self, render_mode: Optional[str] = None, **kwargs):
        """Build the functional cartpole, jit its methods and wrap it.

        Args:
            render_mode: render mode forwarded to `JaxEnv`.
            **kwargs: forwarded both to `EzPickle` and to the `CartPoleF` constructor.
        """
        EzPickle.__init__(self, render_mode=render_mode, **kwargs)

        func_env = CartPoleF(**kwargs)
        func_env.transform(jax.jit)

        super().__init__(
            func_env,
            observation_space=func_env.observation_space,
            action_space=func_env.action_space,
            # A fresh dict per instance, with the same content as the class attribute.
            metadata={"render_modes": ["rgb_array"], "render_fps": 50},
            render_mode=render_mode,
        )
|
121
gymnasium/envs/phys2d/conversion.py
Normal file
121
gymnasium/envs/phys2d/conversion.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax.random as jrng
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium import Space
|
||||||
|
from gymnasium.envs.registration import EnvSpec
|
||||||
|
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||||
|
from gymnasium.utils import seeding
|
||||||
|
|
||||||
|
|
||||||
|
class JaxEnv(gym.Env):
    """
    A conversion layer for jax-based functional environments.

    Wraps a `FuncEnv` in the stateful ``gym.Env`` API: the wrapper stores the
    current state and a jax PRNG key, and converts observations to numpy.
    """

    state: StateType
    rng: jrng.PRNGKey

    def __init__(
        self,
        func_env: FuncEnv,
        observation_space: Space,
        action_space: Space,
        metadata: Optional[Dict[str, Any]] = None,
        render_mode: Optional[str] = None,
        reward_range: Tuple[float, float] = (-float("inf"), float("inf")),
        spec: Optional[EnvSpec] = None,
    ):
        """Initialize the environment from a FuncEnv.

        Args:
            func_env: the functional environment to wrap.
            observation_space: observation space exposed by the wrapper.
            action_space: action space exposed by the wrapper.
            metadata: environment metadata; defaults to an empty dict.
            render_mode: if ``"rgb_array"``, the render state is created immediately.
            reward_range: stored on the instance as-is.
            spec: stored on the instance as-is.
        """
        if metadata is None:
            metadata = {}
        self.func_env = func_env
        self.observation_space = observation_space
        self.action_space = action_space
        self.metadata = metadata
        self.render_mode = render_mode
        self.reward_range = reward_range
        self.spec = spec

        # Cached so `step` does not repeat the isinstance check on every call.
        self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)

        if self.render_mode == "rgb_array":
            self.render_state = self.func_env.render_init()
        else:
            self.render_state = None

        # Draw a random seed so an unseeded env still has a valid PRNG key.
        np_random, _ = seeding.np_random()
        seed = np_random.integers(0, 2**32 - 1, dtype="uint32")

        self.rng = jrng.PRNGKey(seed)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Sample a fresh initial state.

        Args:
            seed: if given, the jax PRNG key is re-created from it.
            options: accepted for API compatibility; not used here.

        Returns:
            ``(observation, info)`` with the observation converted to numpy.
        """
        super().reset(seed=seed)
        if seed is not None:
            self.rng = jrng.PRNGKey(seed)

        # One half of the split is consumed now, the other kept for later calls.
        rng, self.rng = jrng.split(self.rng)

        self.state = self.func_env.initial(rng=rng)
        obs = self.func_env.observation(self.state)
        info = self.func_env.state_info(self.state)

        obs = _convert_jax_to_numpy(obs)

        return obs, info

    def step(self, action: ActType):
        """Advance the environment by one step.

        Box actions are clipped to the space bounds; any other action must be
        contained in the action space.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where the
            observation describes the state reached by this step and
            ``truncated`` is always ``False``.
        """
        if self._is_box_action_space:
            assert isinstance(self.action_space, gym.spaces.Box)  # For typing
            action = np.clip(action, self.action_space.low, self.action_space.high)
        else:  # Discrete
            # For now we assume jax envs don't use complex spaces
            err_msg = f"{action!r} ({type(action)}) invalid"
            assert self.action_space.contains(action), err_msg

        rng, self.rng = jrng.split(self.rng)

        next_state = self.func_env.transition(self.state, action, rng)
        # Observe the state we transitioned INTO; observing `self.state` here
        # would hand the caller an observation that lags one step behind.
        observation = self.func_env.observation(next_state)
        reward = self.func_env.reward(self.state, action, next_state)
        terminated = self.func_env.terminal(next_state)
        info = self.func_env.step_info(self.state, action, next_state)
        self.state = next_state

        observation = _convert_jax_to_numpy(observation)

        return observation, float(reward), bool(terminated), False, info

    def render(self):
        """Return an image of the current state.

        Raises:
            NotImplementedError: if the render mode is not ``"rgb_array"``.
        """
        if self.render_mode == "rgb_array":
            self.render_state, image = self.func_env.render_image(
                self.state, self.render_state
            )
            return image
        else:
            raise NotImplementedError

    def close(self):
        """Release the render state, if one was created."""
        if self.render_state is not None:
            self.func_env.render_close(self.render_state)
            self.render_state = None
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_jax_to_numpy(element: Any):
|
||||||
|
"""
|
||||||
|
Convert a jax observation/action to a numpy array, or a numpy-based container.
|
||||||
|
Currently required because all tests assume that stuff is in numpy arrays, hopefully will be removed soon.
|
||||||
|
"""
|
||||||
|
if isinstance(element, jnp.ndarray):
|
||||||
|
return np.asarray(element)
|
||||||
|
elif isinstance(element, tuple):
|
||||||
|
return tuple(_convert_jax_to_numpy(e) for e in element)
|
||||||
|
elif isinstance(element, list):
|
||||||
|
return [_convert_jax_to_numpy(e) for e in element]
|
||||||
|
elif isinstance(element, dict):
|
||||||
|
return {k: _convert_jax_to_numpy(v) for k, v in element.items()}
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Cannot convert {element} to numpy")
|
201
gymnasium/envs/phys2d/pendulum.py
Normal file
201
gymnasium/envs/phys2d/pendulum.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""
|
||||||
|
Implementation of a Jax-accelerated pendulum environment.
|
||||||
|
"""
|
||||||
|
from os import path
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
from jax.random import PRNGKey
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.envs.phys2d.conversion import JaxEnv
|
||||||
|
from gymnasium.error import DependencyNotInstalled
|
||||||
|
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||||
|
from gymnasium.utils import EzPickle
|
||||||
|
|
||||||
|
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821
|
||||||
|
|
||||||
|
|
||||||
|
class PendulumF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]):
    """Pendulum but in jax and functional.

    The state is ``(theta, theta_dot)``; the observation is
    ``(cos(theta), sin(theta), theta_dot)``. The action is a length-1 array
    holding the torque, which is clipped to ``[-max_torque, max_torque]``.
    """

    max_speed = 8
    max_torque = 2.0
    dt = 0.05
    g = 10.0
    m = 1.0
    l = 1.0
    high_x = jnp.pi  # bound of the initial angle
    high_y = 1.0  # bound of the initial angular velocity

    screen_dim = 500

    observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32)
    action_space = gym.spaces.Box(-max_torque, max_torque, shape=(1,), dtype=np.float32)

    def initial(self, rng: PRNGKey):
        """Initial state generation.

        Args:
            rng: jax PRNG key used to sample the state.

        Returns:
            ``(theta, theta_dot)`` sampled uniformly from
            ``[-high_x, high_x) x [-high_y, high_y)``.
        """
        high = jnp.array([self.high_x, self.high_y])
        return jax.random.uniform(key=rng, minval=-high, maxval=high, shape=high.shape)

    def transition(
        self, state: jnp.ndarray, action: Union[int, jnp.ndarray], rng: None = None
    ) -> jnp.ndarray:
        """Pendulum transition.

        Args:
            state: current ``(theta, theta_dot)``.
            action: indexable torque; only element 0 is used, after clipping.
            rng: unused, the dynamics are deterministic.

        Returns:
            The next ``(theta, theta_dot)``.
        """
        th, thdot = state  # th := theta
        u = action

        g = self.g
        m = self.m
        l = self.l
        dt = self.dt

        u = jnp.clip(u, -self.max_torque, self.max_torque)[0]

        newthdot = thdot + (3 * g / (2 * l) * jnp.sin(th) + 3.0 / (m * l**2) * u) * dt
        newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)
        # The angle is advanced with the already-updated velocity.
        newth = th + newthdot * dt

        new_state = jnp.array([newth, newthdot])
        return new_state

    def observation(self, state: jnp.ndarray) -> jnp.ndarray:
        """Return ``(cos(theta), sin(theta), theta_dot)`` for `state`."""
        theta, thetadot = state
        return jnp.array([jnp.cos(theta), jnp.sin(theta), thetadot])

    def reward(self, state: StateType, action: ActType, next_state: StateType) -> float:
        """Return the negative cost of taking `action` in `state`.

        The cost is ``theta_norm**2 + 0.1 * theta_dot**2 + 0.001 * u**2`` with
        the angle wrapped into ``[-pi, pi)`` and ``u`` the clipped torque.
        `next_state` is not used.
        """
        th, thdot = state  # th := theta
        u = action

        u = jnp.clip(u, -self.max_torque, self.max_torque)[0]

        th_normalized = ((th + jnp.pi) % (2 * jnp.pi)) - jnp.pi
        costs = th_normalized**2 + 0.1 * thdot**2 + 0.001 * (u**2)

        return -costs

    def terminal(self, state: StateType) -> bool:
        """The pendulum task never terminates."""
        return False

    def render_image(
        self,
        state: StateType,
        render_state: RenderStateType,
    ) -> Tuple[RenderStateType, np.ndarray]:
        """Draw `state` onto the render surface and return it as an image array.

        Args:
            state: the state to draw; only ``state[0]`` (theta) is read.
            render_state: ``(screen, clock, last_u)`` as produced by `render_init`.
                When ``last_u`` is not ``None`` a torque arrow scaled by its
                magnitude is drawn as well.

        Returns:
            The (unchanged) render state and the screen pixels with the first
            two axes swapped relative to ``pygame.surfarray.pixels3d``.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
            from pygame import gfxdraw
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            ) from e
        screen, clock, last_u = render_state

        # NOTE(review): the scene is screen_dim x screen_dim, while render_init
        # defaults to a 600x400 screen - confirm the intended output size.
        surf = pygame.Surface((self.screen_dim, self.screen_dim))
        surf.fill((255, 255, 255))

        bound = 2.2
        scale = self.screen_dim / (bound * 2)
        offset = self.screen_dim // 2

        # Rod rectangle, rotated about the centre of the surface.
        rod_length = 1 * scale
        rod_width = 0.2 * scale
        l, r, t, b = 0, rod_length, rod_width / 2, -rod_width / 2
        coords = [(l, b), (l, t), (r, t), (r, b)]
        transformed_coords = []
        for c in coords:
            c = pygame.math.Vector2(c).rotate_rad(state[0] + np.pi / 2)
            c = (c[0] + offset, c[1] + offset)
            transformed_coords.append(c)
        gfxdraw.aapolygon(surf, transformed_coords, (204, 77, 77))
        gfxdraw.filled_polygon(surf, transformed_coords, (204, 77, 77))

        # Rounded cap at the pivot end of the rod.
        gfxdraw.aacircle(surf, offset, offset, int(rod_width / 2), (204, 77, 77))
        gfxdraw.filled_circle(surf, offset, offset, int(rod_width / 2), (204, 77, 77))

        # Rounded cap at the free end of the rod.
        rod_end = (rod_length, 0)
        rod_end = pygame.math.Vector2(rod_end).rotate_rad(state[0] + np.pi / 2)
        rod_end = (int(rod_end[0] + offset), int(rod_end[1] + offset))
        gfxdraw.aacircle(
            surf, rod_end[0], rod_end[1], int(rod_width / 2), (204, 77, 77)
        )
        gfxdraw.filled_circle(
            surf, rod_end[0], rod_end[1], int(rod_width / 2), (204, 77, 77)
        )

        if last_u is not None:
            # Only touch the disk when the torque arrow is actually drawn.
            fname = path.join(path.dirname(__file__), "assets/clockwise.png")
            img = pygame.image.load(fname)
            scale_img = pygame.transform.smoothscale(
                img,
                (scale * np.abs(last_u) / 2, scale * np.abs(last_u) / 2),
            )
            is_flip = bool(last_u > 0)
            scale_img = pygame.transform.flip(scale_img, is_flip, True)
            surf.blit(
                scale_img,
                (
                    offset - scale_img.get_rect().centerx,
                    offset - scale_img.get_rect().centery,
                ),
            )

        # drawing axle
        gfxdraw.aacircle(surf, offset, offset, int(0.05 * scale), (0, 0, 0))
        gfxdraw.filled_circle(surf, offset, offset, int(0.05 * scale), (0, 0, 0))

        # The scene is drawn with y pointing up; flip vertically before blitting.
        surf = pygame.transform.flip(surf, False, True)
        screen.blit(surf, (0, 0))

        return (screen, clock, last_u), np.transpose(
            np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2)
        )

    def render_init(
        self, screen_width: int = 600, screen_height: int = 400
    ) -> RenderStateType:
        """Initialise pygame and create the off-screen surface and clock.

        Args:
            screen_width: width of the surface in pixels.
            screen_height: height of the surface in pixels.

        Returns:
            The ``(screen, clock, last_u)`` render state, with ``last_u`` set to ``None``.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            ) from e

        pygame.init()
        screen = pygame.Surface((screen_width, screen_height))
        clock = pygame.time.Clock()

        return screen, clock, None

    def render_close(self, render_state: RenderStateType) -> None:
        """Shut down the pygame display module and pygame itself.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            ) from e
        pygame.display.quit()
        pygame.quit()
|
||||||
|
|
||||||
|
|
||||||
|
class PendulumJaxEnv(JaxEnv, EzPickle):
    """Stateful ``gym.Env`` wrapper around a jit-compiled `PendulumF`."""

    metadata = {"render_modes": ["rgb_array"], "render_fps": 30}

    def __init__(self, render_mode: Optional[str] = None, **kwargs):
        """Build the functional pendulum, jit its methods and wrap it.

        Args:
            render_mode: render mode forwarded to `JaxEnv`.
            **kwargs: forwarded both to `EzPickle` and to the `PendulumF` constructor.
        """
        EzPickle.__init__(self, render_mode=render_mode, **kwargs)

        func_env = PendulumF(**kwargs)
        func_env.transform(jax.jit)

        super().__init__(
            func_env,
            observation_space=func_env.observation_space,
            action_space=func_env.action_space,
            # A fresh dict per instance, with the same content as the class attribute.
            metadata={"render_modes": ["rgb_array"], "render_fps": 30},
            render_mode=render_mode,
        )
|
96
gymnasium/functional.py
Normal file
96
gymnasium/functional.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
"""Base class and definitions for an alternative, functional backend for gym envs, particularly suitable for hardware accelerated and otherwise transformed environments."""
|
||||||
|
|
||||||
|
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
StateType = TypeVar("StateType")
|
||||||
|
ActType = TypeVar("ActType")
|
||||||
|
ObsType = TypeVar("ObsType")
|
||||||
|
RewardType = TypeVar("RewardType")
|
||||||
|
TerminalType = TypeVar("TerminalType")
|
||||||
|
RenderStateType = TypeVar("RenderStateType")
|
||||||
|
|
||||||
|
|
||||||
|
class FuncEnv(
    Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType]
):
    """Base class (template) for functional envs.

    The API is designed for stateless use: the environment state is an explicit
    argument and return value rather than something stored on the object.
    Stateful use is not prevented, merely discouraged.

    A functional env is made of these instance methods:
    - initial: returns the initial state of the POMDP
    - observation: returns the observation in a given state
    - transition: returns the next state after taking an action in a given state
    - reward: returns the reward for a given (state, action, next_state) tuple
    - terminal: returns whether a given state is terminal
    - state_info: optional, returns a dict of info about a given state
    - step_info: optional, returns a dict of info about a given (state, action, next_state) tuple

    It is a class so that environment constants can live on it and be referred
    to by name from the methods.

    For now this is mainly for internal use and the API is likely to change;
    the intention is to flesh it out and expose it officially later.
    """

    def __init__(self, options: Optional[Dict[str, Any]] = None):
        """Initialize the environment constants.

        Args:
            options: attribute names and values copied into the instance ``__dict__``.
        """
        if options:
            self.__dict__.update(options)

    def initial(self, rng: Any) -> StateType:
        """Initial state."""
        raise NotImplementedError

    def observation(self, state: StateType) -> ObsType:
        """Observation."""
        raise NotImplementedError

    def transition(self, state: StateType, action: ActType, rng: Any) -> StateType:
        """Transition."""
        raise NotImplementedError

    def reward(
        self, state: StateType, action: ActType, next_state: StateType
    ) -> RewardType:
        """Reward."""
        raise NotImplementedError

    def terminal(self, state: StateType) -> TerminalType:
        """Terminal state."""
        raise NotImplementedError

    def state_info(self, state: StateType) -> dict:
        """Info dict about a single state."""
        return {}

    def step_info(
        self, state: StateType, action: ActType, next_state: StateType
    ) -> dict:
        """Info dict about a full transition."""
        return {}

    def transform(self, func: Callable[[Callable], Callable]):
        """Functional transformations.

        Replaces each core method on this instance by ``func(method)``. The
        render methods are left untouched.
        """
        method_names = (
            "initial",
            "transition",
            "observation",
            "reward",
            "terminal",
            "state_info",
            "step_info",
        )
        for method_name in method_names:
            setattr(self, method_name, func(getattr(self, method_name)))

    def render_image(
        self, state: StateType, render_state: RenderStateType
    ) -> Tuple[RenderStateType, np.ndarray]:
        """Show the state."""
        raise NotImplementedError

    def render_init(self, **kwargs) -> RenderStateType:
        """Initialize the render state."""
        raise NotImplementedError

    def render_close(self, render_state: RenderStateType):
        """Close the render state."""
        raise NotImplementedError
|
@@ -12,4 +12,4 @@ pygame==2.1.0
|
|||||||
ale-py~=0.8.0
|
ale-py~=0.8.0
|
||||||
mujoco==2.2
|
mujoco==2.2
|
||||||
mujoco_py<2.2,>=2.1
|
mujoco_py<2.2,>=2.1
|
||||||
imageio>=2.14.1
|
imageio>=2.14.1
|
||||||
|
1
setup.py
1
setup.py
@@ -40,6 +40,7 @@ extras = {
|
|||||||
"mujoco_py": ["mujoco_py<2.2,>=2.1"],
|
"mujoco_py": ["mujoco_py<2.2,>=2.1"],
|
||||||
"mujoco": ["mujoco==2.2", "imageio>=2.14.1"],
|
"mujoco": ["mujoco==2.2", "imageio>=2.14.1"],
|
||||||
"toy_text": ["pygame==2.1.0"],
|
"toy_text": ["pygame==2.1.0"],
|
||||||
|
"jax": ["jax==0.3.20", "jaxlib==0.3.20"],
|
||||||
"other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0", "moviepy>=1.0.0"],
|
"other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0", "moviepy>=1.0.0"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
0
tests/envs/functional/__init__.py
Normal file
0
tests/envs/functional/__init__.py
Normal file
56
tests/envs/functional/test_core.py
Normal file
56
tests/envs/functional/test_core.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from gymnasium.functional import FuncEnv
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnv(FuncEnv):
    """Minimal numpy functional env used to exercise the `FuncEnv` API.

    The state is a float32 pair; each action is added to the second component.
    """

    # The name matches pytest's ``Test*`` pattern but this is a fixture class,
    # not a test case: opt out of collection to avoid a collection warning.
    __test__ = False

    def __init__(self, options: Optional[Dict[str, Any]] = None):
        super().__init__(options)

    def initial(self, rng: Any) -> np.ndarray:
        """Return the all-zero float32 state; `rng` is ignored."""
        return np.array([0, 0], dtype=np.float32)

    def observation(self, state: np.ndarray) -> np.ndarray:
        """Return the state itself."""
        return state

    def transition(self, state: np.ndarray, action: int, rng: None) -> np.ndarray:
        """Add `action` to the second state component."""
        return state + np.array([0, action], dtype=np.float32)

    def reward(self, state: np.ndarray, action: int, next_state: np.ndarray) -> float:
        """Return 1.0 when the second component of `next_state` is positive, else 0.0."""
        return 1.0 if next_state[1] > 0 else 0.0

    def terminal(self, state: np.ndarray) -> bool:
        """Return whether the second state component is positive."""
        return state[1] > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_api():
    """Drive TestEnv through a fixed action sequence and check each API call."""

    def check_vector(arr):
        # Every state/observation in this test is a float32 pair.
        assert arr.shape == (2,)
        assert arr.dtype == np.float32

    env = TestEnv()
    state = env.initial(None)
    obs = env.observation(state)
    check_vector(state)
    check_vector(obs)
    assert np.allclose(obs, state)

    actions = [-1, -2, -5, 3, 5, 2]
    for step_index, act in enumerate(actions):
        next_state = env.transition(state, act, None)
        check_vector(next_state)
        assert np.allclose(next_state, state + np.array([0, act]))

        observation = env.observation(next_state)
        check_vector(observation)
        assert np.allclose(observation, next_state)

        expected_reward = 1.0 if next_state[1] > 0 else 0.0
        assert env.reward(state, act, next_state) == expected_reward

        # terminal state is in the final action
        is_terminal = env.terminal(next_state)
        assert is_terminal == (step_index == 5)

        state = next_state
|
105
tests/envs/functional/test_jax.py
Normal file
105
tests/envs/functional/test_jax.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax.random as jrng
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gymnasium.envs.phys2d.cartpole import CartPoleF # noqa: E402
|
||||||
|
from gymnasium.envs.phys2d.pendulum import PendulumF # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF])
def test_normal(env_class):
    """Run ten untransformed steps and check shapes, dtypes and scalar outputs."""
    env = env_class()
    rng = jrng.PRNGKey(0)

    state = env.initial(rng)
    env.action_space.seed(0)

    for t in range(10):
        obs = env.observation(state)
        action = env.action_space.sample()
        next_state = env.transition(state, action, None)
        reward = env.reward(state, action, next_state)
        terminal = env.terminal(next_state)

        assert next_state.shape == state.shape
        # Converting a non-scalar array raises TypeError for float() and
        # ValueError for bool(); catch both so the failure message is reported.
        try:
            float(reward)
        except (TypeError, ValueError):
            pytest.fail("Reward is not castable to float")
        try:
            bool(terminal)
        except (TypeError, ValueError):
            pytest.fail("Terminal is not castable to bool")

        assert next_state.dtype == jnp.float32
        assert isinstance(obs, jnp.ndarray)
        assert obs.dtype == jnp.float32

        state = next_state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF])
def test_jit(env_class):
    """Roll out a functional env after applying ``jax.jit`` via ``transform``.

    For ten sampled actions this verifies that the state keeps its shape and
    float32 dtype, that observations are float32 jax arrays, and that the
    reward and terminal values can be cast to ``float`` and ``bool``.
    """
    env = env_class()
    rng = jrng.PRNGKey(0)

    env.transform(jax.jit)
    state = env.initial(rng)
    env.action_space.seed(0)

    for _ in range(10):
        obs = env.observation(state)
        action = env.action_space.sample()
        next_state = env.transition(state, action, None)
        reward = env.reward(state, action, next_state)
        terminal = env.terminal(next_state)

        assert next_state.shape == state.shape

        # float()/bool() report an unconvertible value with either TypeError
        # (e.g. a non-scalar array passed to float()) or ValueError, so both
        # are caught to produce the readable failure message.
        try:
            float(reward)
        except (TypeError, ValueError):
            pytest.fail("Reward is not castable to float")
        try:
            bool(terminal)
        except (TypeError, ValueError):
            pytest.fail("Terminal is not castable to bool")

        assert next_state.dtype == jnp.float32
        assert isinstance(obs, jnp.ndarray)
        assert obs.dtype == jnp.float32

        state = next_state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF])
def test_vmap(env_class):
    """Roll out a batch of envs after applying ``jax.vmap`` then ``jax.jit``.

    Uses ``num_envs`` PRNG keys and a stacked array of sampled actions, and
    checks on every step that the batched state keeps its shape and float32
    dtype, that rewards are float32 with shape ``(num_envs,)``, that terminals
    are boolean with shape ``(num_envs,)``, and that observations are float32
    jax arrays.
    """
    env = env_class()
    num_envs = 10
    rng = jrng.split(jrng.PRNGKey(0), num_envs)

    env.transform(jax.vmap)
    env.transform(jax.jit)
    state = env.initial(rng)
    env.action_space.seed(0)

    for _ in range(10):
        obs = env.observation(state)
        action = jnp.array([env.action_space.sample() for _ in range(num_envs)])
        next_state = env.transition(state, action, None)
        terminal = env.terminal(next_state)
        reward = env.reward(state, action, next_state)

        assert next_state.shape == state.shape
        assert next_state.dtype == jnp.float32
        assert reward.shape == (num_envs,)
        assert reward.dtype == jnp.float32
        assert terminal.shape == (num_envs,)
        # ``np.bool`` was removed in NumPy 1.24; ``np.bool_`` is the scalar
        # type that is available across NumPy versions.
        assert terminal.dtype == np.bool_
        assert isinstance(obs, jnp.ndarray)
        assert obs.dtype == jnp.float32

        state = next_state
|
@@ -171,7 +171,6 @@ def test_render_modes(spec):
|
|||||||
env = spec.make()
|
env = spec.make()
|
||||||
|
|
||||||
assert "rgb_array" in env.metadata["render_modes"]
|
assert "rgb_array" in env.metadata["render_modes"]
|
||||||
assert "human" in env.metadata["render_modes"]
|
|
||||||
|
|
||||||
for mode in env.metadata["render_modes"]:
|
for mode in env.metadata["render_modes"]:
|
||||||
if mode != "human":
|
if mode != "human":
|
||||||
|
@@ -18,173 +18,3 @@ CartPole-v1
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
assert out == correct_out
|
assert out == correct_out
|
||||||
|
|
||||||
|
|
||||||
def test_pprint_registry():
|
|
||||||
"""Testing the default registry, with no changes."""
|
|
||||||
out = gym.pprint_registry(disable_print=True)
|
|
||||||
|
|
||||||
correct_out = """===== classic_control =====
|
|
||||||
Acrobot-v1
|
|
||||||
CartPole-v0
|
|
||||||
CartPole-v1
|
|
||||||
MountainCar-v0
|
|
||||||
MountainCarContinuous-v0
|
|
||||||
Pendulum-v1
|
|
||||||
|
|
||||||
===== box2d =====
|
|
||||||
BipedalWalker-v3
|
|
||||||
BipedalWalkerHardcore-v3
|
|
||||||
CarRacing-v2
|
|
||||||
LunarLander-v2
|
|
||||||
LunarLanderContinuous-v2
|
|
||||||
|
|
||||||
===== toy_text =====
|
|
||||||
Blackjack-v1
|
|
||||||
CliffWalking-v0
|
|
||||||
FrozenLake-v1
|
|
||||||
FrozenLake8x8-v1
|
|
||||||
Taxi-v3
|
|
||||||
|
|
||||||
===== mujoco =====
|
|
||||||
Ant-v2 Ant-v3 Ant-v4
|
|
||||||
HalfCheetah-v2 HalfCheetah-v3 HalfCheetah-v4
|
|
||||||
Hopper-v2 Hopper-v3 Hopper-v4
|
|
||||||
Humanoid-v2 Humanoid-v3 Humanoid-v4
|
|
||||||
HumanoidStandup-v2 HumanoidStandup-v4 InvertedDoublePendulum-v2
|
|
||||||
InvertedDoublePendulum-v4 InvertedPendulum-v2 InvertedPendulum-v4
|
|
||||||
Pusher-v2 Pusher-v4 Reacher-v2
|
|
||||||
Reacher-v4 Swimmer-v2 Swimmer-v3
|
|
||||||
Swimmer-v4 Walker2d-v2 Walker2d-v3
|
|
||||||
Walker2d-v4
|
|
||||||
|
|
||||||
===== external =====
|
|
||||||
GymV26Environment-v0
|
|
||||||
|
|
||||||
===== utils_envs =====
|
|
||||||
RegisterDuringMakeEnv-v0
|
|
||||||
test.ArgumentEnv-v0
|
|
||||||
test.OrderlessArgumentEnv-v0
|
|
||||||
|
|
||||||
===== test =====
|
|
||||||
test/NoHuman-v0
|
|
||||||
test/NoHumanNoRGB-v0
|
|
||||||
test/NoHumanOldAPI-v0
|
|
||||||
|
|
||||||
"""
|
|
||||||
assert out == correct_out
|
|
||||||
|
|
||||||
|
|
||||||
def test_pprint_registry_exclude_namespaces():
|
|
||||||
"""Testing the default registry, with no changes."""
|
|
||||||
out = gym.pprint_registry(
|
|
||||||
max_rows=20, exclude_namespaces=["classic_control"], disable_print=True
|
|
||||||
)
|
|
||||||
|
|
||||||
correct_out = """===== box2d =====
|
|
||||||
BipedalWalker-v3
|
|
||||||
BipedalWalkerHardcore-v3
|
|
||||||
CarRacing-v2
|
|
||||||
LunarLander-v2
|
|
||||||
LunarLanderContinuous-v2
|
|
||||||
|
|
||||||
===== toy_text =====
|
|
||||||
Blackjack-v1
|
|
||||||
CliffWalking-v0
|
|
||||||
FrozenLake-v1
|
|
||||||
FrozenLake8x8-v1
|
|
||||||
Taxi-v3
|
|
||||||
|
|
||||||
===== mujoco =====
|
|
||||||
Ant-v2 Ant-v3
|
|
||||||
Ant-v4 HalfCheetah-v2
|
|
||||||
HalfCheetah-v3 HalfCheetah-v4
|
|
||||||
Hopper-v2 Hopper-v3
|
|
||||||
Hopper-v4 Humanoid-v2
|
|
||||||
Humanoid-v3 Humanoid-v4
|
|
||||||
HumanoidStandup-v2 HumanoidStandup-v4
|
|
||||||
InvertedDoublePendulum-v2 InvertedDoublePendulum-v4
|
|
||||||
InvertedPendulum-v2 InvertedPendulum-v4
|
|
||||||
Pusher-v2 Pusher-v4
|
|
||||||
Reacher-v2 Reacher-v4
|
|
||||||
Swimmer-v2 Swimmer-v3
|
|
||||||
Swimmer-v4 Walker2d-v2
|
|
||||||
Walker2d-v3 Walker2d-v4
|
|
||||||
|
|
||||||
===== external =====
|
|
||||||
GymV26Environment-v0
|
|
||||||
|
|
||||||
===== utils_envs =====
|
|
||||||
RegisterDuringMakeEnv-v0
|
|
||||||
test.ArgumentEnv-v0
|
|
||||||
test.OrderlessArgumentEnv-v0
|
|
||||||
|
|
||||||
===== test =====
|
|
||||||
test/NoHuman-v0
|
|
||||||
test/NoHumanNoRGB-v0
|
|
||||||
test/NoHumanOldAPI-v0
|
|
||||||
|
|
||||||
"""
|
|
||||||
assert out == correct_out
|
|
||||||
|
|
||||||
|
|
||||||
def test_pprint_registry_no_entry_point():
|
|
||||||
"""Test registry if there is environment with no entry point."""
|
|
||||||
|
|
||||||
gym.register("NoNamespaceEnv", "no-entry-point")
|
|
||||||
out = gym.pprint_registry(disable_print=True)
|
|
||||||
|
|
||||||
correct_out = """===== classic_control =====
|
|
||||||
Acrobot-v1
|
|
||||||
CartPole-v0
|
|
||||||
CartPole-v1
|
|
||||||
MountainCar-v0
|
|
||||||
MountainCarContinuous-v0
|
|
||||||
Pendulum-v1
|
|
||||||
|
|
||||||
===== box2d =====
|
|
||||||
BipedalWalker-v3
|
|
||||||
BipedalWalkerHardcore-v3
|
|
||||||
CarRacing-v2
|
|
||||||
LunarLander-v2
|
|
||||||
LunarLanderContinuous-v2
|
|
||||||
|
|
||||||
===== toy_text =====
|
|
||||||
Blackjack-v1
|
|
||||||
CliffWalking-v0
|
|
||||||
FrozenLake-v1
|
|
||||||
FrozenLake8x8-v1
|
|
||||||
Taxi-v3
|
|
||||||
|
|
||||||
===== mujoco =====
|
|
||||||
Ant-v2 Ant-v3 Ant-v4
|
|
||||||
HalfCheetah-v2 HalfCheetah-v3 HalfCheetah-v4
|
|
||||||
Hopper-v2 Hopper-v3 Hopper-v4
|
|
||||||
Humanoid-v2 Humanoid-v3 Humanoid-v4
|
|
||||||
HumanoidStandup-v2 HumanoidStandup-v4 InvertedDoublePendulum-v2
|
|
||||||
InvertedDoublePendulum-v4 InvertedPendulum-v2 InvertedPendulum-v4
|
|
||||||
Pusher-v2 Pusher-v4 Reacher-v2
|
|
||||||
Reacher-v4 Swimmer-v2 Swimmer-v3
|
|
||||||
Swimmer-v4 Walker2d-v2 Walker2d-v3
|
|
||||||
Walker2d-v4
|
|
||||||
|
|
||||||
===== external =====
|
|
||||||
GymV26Environment-v0
|
|
||||||
|
|
||||||
===== utils_envs =====
|
|
||||||
RegisterDuringMakeEnv-v0
|
|
||||||
test.ArgumentEnv-v0
|
|
||||||
test.OrderlessArgumentEnv-v0
|
|
||||||
|
|
||||||
===== test =====
|
|
||||||
test/NoHuman-v0
|
|
||||||
test/NoHumanNoRGB-v0
|
|
||||||
test/NoHumanOldAPI-v0
|
|
||||||
|
|
||||||
===== NoNamespaceEnv =====
|
|
||||||
NoNamespaceEnv
|
|
||||||
|
|
||||||
"""
|
|
||||||
assert out == correct_out
|
|
||||||
|
|
||||||
del gym.envs.registry["NoNamespaceEnv"]
|
|
||||||
|
@@ -2,7 +2,7 @@ import pytest
|
|||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
|
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
|
||||||
from gymnasium.wrappers import OrderEnforcing, TimeLimit, TransformObservation
|
from gymnasium.wrappers import TimeLimit, TransformObservation
|
||||||
from gymnasium.wrappers.env_checker import PassiveEnvChecker
|
from gymnasium.wrappers.env_checker import PassiveEnvChecker
|
||||||
from tests.wrappers.utils import has_wrapper
|
from tests.wrappers.utils import has_wrapper
|
||||||
|
|
||||||
@@ -39,8 +39,6 @@ def test_vector_make_wrappers():
|
|||||||
sub_env = env.envs[0]
|
sub_env = env.envs[0]
|
||||||
assert isinstance(sub_env, gym.Env)
|
assert isinstance(sub_env, gym.Env)
|
||||||
assert sub_env.spec is not None
|
assert sub_env.spec is not None
|
||||||
if sub_env.spec.order_enforce:
|
|
||||||
assert has_wrapper(sub_env, OrderEnforcing)
|
|
||||||
if sub_env.spec.max_episode_steps is not None:
|
if sub_env.spec.max_episode_steps is not None:
|
||||||
assert has_wrapper(sub_env, TimeLimit)
|
assert has_wrapper(sub_env, TimeLimit)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user