Functional API and proof-of-concept jax classic-control envs (#25) (#145)

This commit is contained in:
Ariel Kwiatkowski
2022-11-18 22:25:33 +01:00
committed by GitHub
parent a93da8f271
commit 34dfc9a728
15 changed files with 860 additions and 175 deletions

View File

@@ -49,6 +49,30 @@ register(
max_episode_steps=500,
)
# Phys2d (jax classic control)
# ----------------------------------------

# Both CartPoleJax versions share the same entry point; they differ only in
# the episode step limit and the reward threshold.
register(
    id="CartPoleJax-v0",
    entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxEnv",
    max_episode_steps=200,
    reward_threshold=195.0,
)
register(
    id="CartPoleJax-v1",
    entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxEnv",
    max_episode_steps=500,
    reward_threshold=475.0,
)
# The jax pendulum is registered without a reward threshold.
register(
    id="PendulumJax-v0",
    entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxEnv",
    max_episode_steps=200,
)
# Box2d
# ----------------------------------------

View File

@@ -0,0 +1,2 @@
from gymnasium.envs.phys2d.cartpole import CartPoleF
from gymnasium.envs.phys2d.pendulum import PendulumF

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.8 KiB

View File

@@ -0,0 +1,252 @@
"""
Implementation of a Jax-accelerated cartpole environment.
"""
from typing import Optional, Tuple, Union
import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey
import gymnasium as gym
from gymnasium.envs.phys2d.conversion import JaxEnv
from gymnasium.error import DependencyNotInstalled
from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle
# Render state passed between render_init / render_image / render_close:
# the pygame surface and the pygame clock created by render_init.
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"]  # type: ignore  # noqa: F821
class CartPoleF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]):
    """Cartpole but in jax and functional.

    The state is the 4-vector ``(x, x_dot, theta, theta_dot)`` and is returned
    unchanged as the observation. Action ``0`` applies a force of ``-force_mag``
    and action ``1`` a force of ``+force_mag`` along x.

    Example usage:

    ```
    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(0)

    env = CartPoleF({"x_init": 0.5})
    state = env.initial(key)
    print(state)
    print(env.transition(state, 0))

    env.transform(jax.jit)

    state = env.initial(key)
    print(state)
    print(env.transition(state, 0))

    vkey = jax.random.split(key, 10)
    env.transform(jax.vmap)
    vstate = env.initial(vkey)
    print(vstate)
    print(env.transition(vstate, jnp.array([0 for _ in range(10)])))
    ```
    """

    # NOTE: total_mass and polemass_length are computed once, at class-definition
    # time, from the class-level defaults; overriding masspole/masscart/length
    # through ``options`` does not recompute them.
    gravity = 9.8
    masscart = 1.0
    masspole = 0.1
    total_mass = masspole + masscart
    length = 0.5  # the renderer draws the pole with length 2 * length
    # Pole mass times length, as it appears in the cart-pole equations of motion.
    polemass_length = masspole * length
    force_mag = 10.0
    tau = 0.02  # integration time step used by transition
    theta_threshold_radians = 12 * 2 * np.pi / 360
    x_threshold = 2.4

    x_init = 0.05  # half-width of the uniform range used by initial

    screen_width = 600
    screen_height = 400

    observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(4,), dtype=np.float32)
    action_space = gym.spaces.Discrete(2)

    def initial(self, rng: PRNGKey):
        """Sample an initial state uniformly from ``[-x_init, x_init]`` in each of the 4 components."""
        return jax.random.uniform(
            key=rng, minval=-self.x_init, maxval=self.x_init, shape=(4,)
        )

    def transition(
        self, state: jnp.ndarray, action: Union[int, jnp.ndarray], rng: None = None
    ) -> StateType:
        """Advance the cart-pole by one explicit Euler step of size ``tau``.

        Args:
            state: current ``(x, x_dot, theta, theta_dot)``.
            action: ``0`` or ``1``; mapped to a force of ``-force_mag`` / ``+force_mag``.
            rng: unused; accepted for interface compatibility.

        Returns:
            The next state as a float32 array of shape ``(4,)``.
        """
        x, x_dot, theta, theta_dot = state
        force = jnp.sign(action - 0.5) * self.force_mag
        costheta = jnp.cos(theta)
        sintheta = jnp.sin(theta)

        # For the interested reader:
        # https://coneural.org/florian/papers/05_cart_pole.pdf
        temp = (
            force + self.polemass_length * theta_dot**2 * sintheta
        ) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length * (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass)
        )
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        # Positions are advanced with the velocities from before this step.
        x = x + self.tau * x_dot
        x_dot = x_dot + self.tau * xacc
        theta = theta + self.tau * theta_dot
        theta_dot = theta_dot + self.tau * thetaacc

        state = jnp.array((x, x_dot, theta, theta_dot), dtype=jnp.float32)

        return state

    def observation(self, state: jnp.ndarray) -> jnp.ndarray:
        """Cartpole observation: the state itself."""
        return state

    def _out_of_bounds(self, state: jnp.ndarray) -> jnp.ndarray:
        """Return whether the cart position or pole angle in ``state`` exceeds its threshold."""
        x, _, theta, _ = state
        return (
            (x < -self.x_threshold)
            | (x > self.x_threshold)
            | (theta < -self.theta_threshold_radians)
            | (theta > self.theta_threshold_radians)
        )

    def terminal(self, state: jnp.ndarray) -> jnp.ndarray:
        """Return whether ``state`` is terminal (cart or pole outside its threshold)."""
        return self._out_of_bounds(state)

    def reward(
        self, state: StateType, action: ActType, next_state: StateType
    ) -> jnp.ndarray:
        """Return 1.0 unless the state the step started from was already out of bounds.

        The check is made on ``state`` (not ``next_state``), so the step that
        first leaves the bounds is still rewarded with 1.0.
        """
        terminated = self._out_of_bounds(state)
        reward = jax.lax.cond(terminated, lambda: 0.0, lambda: 1.0)
        return reward

    def render_image(
        self,
        state: StateType,
        render_state: RenderStateType,
    ) -> Tuple[RenderStateType, np.ndarray]:
        """Draw ``state`` onto the render surface and return it as an image array.

        Args:
            state: the state to draw; components 0 (cart x) and 2 (pole angle) are used.
            render_state: the ``(screen, clock)`` pair from ``render_init``.

        Returns:
            The (unchanged) render state and the screen pixels with the first
            two axes swapped relative to pygame's ``pixels3d`` layout.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
            from pygame import gfxdraw
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            ) from e
        screen, clock = render_state

        world_width = self.x_threshold * 2
        scale = self.screen_width / world_width
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        x = state

        surf = pygame.Surface((self.screen_width, self.screen_height))
        surf.fill((255, 255, 255))

        l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
        axleoffset = cartheight / 4.0
        cartx = x[0] * scale + self.screen_width / 2.0  # MIDDLE OF CART
        carty = 100  # TOP OF CART
        cart_coords = [(l, b), (l, t), (r, t), (r, b)]
        cart_coords = [(c[0] + cartx, c[1] + carty) for c in cart_coords]
        gfxdraw.aapolygon(surf, cart_coords, (0, 0, 0))
        gfxdraw.filled_polygon(surf, cart_coords, (0, 0, 0))

        l, r, t, b = (
            -polewidth / 2,
            polewidth / 2,
            polelen - polewidth / 2,
            -polewidth / 2,
        )

        pole_coords = []
        for coord in [(l, b), (l, t), (r, t), (r, b)]:
            coord = pygame.math.Vector2(coord).rotate_rad(-x[2])
            coord = (coord[0] + cartx, coord[1] + carty + axleoffset)
            pole_coords.append(coord)
        gfxdraw.aapolygon(surf, pole_coords, (202, 152, 101))
        gfxdraw.filled_polygon(surf, pole_coords, (202, 152, 101))

        gfxdraw.aacircle(
            surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )
        gfxdraw.filled_circle(
            surf,
            int(cartx),
            int(carty + axleoffset),
            int(polewidth / 2),
            (129, 132, 203),
        )

        gfxdraw.hline(surf, 0, self.screen_width, carty, (0, 0, 0))

        # The scene is drawn with y pointing up, then flipped vertically for the screen.
        surf = pygame.transform.flip(surf, False, True)
        screen.blit(surf, (0, 0))

        return (screen, clock), np.transpose(
            np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2)
        )

    def render_init(
        self, screen_width: int = 600, screen_height: int = 400
    ) -> RenderStateType:
        """Initialise pygame and create the ``(screen, clock)`` render state.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            ) from e

        pygame.init()
        screen = pygame.Surface((screen_width, screen_height))
        clock = pygame.time.Clock()

        return screen, clock

    def render_close(self, render_state: RenderStateType) -> None:
        """Shut down the pygame display module and pygame itself.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            ) from e
        pygame.display.quit()
        pygame.quit()
class CartPoleJaxEnv(JaxEnv, EzPickle):
    """Stateful environment built from a jit-transformed ``CartPoleF``."""

    metadata = {"render_modes": ["rgb_array"], "render_fps": 50}

    def __init__(self, render_mode: Optional[str] = None, **kwargs):
        """Build the functional cartpole, jit it, and hand it to ``JaxEnv``.

        Args:
            render_mode: forwarded to ``JaxEnv``.
            **kwargs: forwarded verbatim to the ``CartPoleF`` constructor and
                recorded by ``EzPickle``.
        """
        EzPickle.__init__(self, render_mode=render_mode, **kwargs)

        func_env = CartPoleF(**kwargs)
        func_env.transform(jax.jit)

        super().__init__(
            func_env,
            observation_space=func_env.observation_space,
            action_space=func_env.action_space,
            metadata={"render_modes": ["rgb_array"], "render_fps": 50},
            render_mode=render_mode,
        )

View File

@@ -0,0 +1,121 @@
from typing import Any, Dict, Optional, Tuple
import jax.numpy as jnp
import jax.random as jrng
import numpy as np
import gymnasium as gym
from gymnasium import Space
from gymnasium.envs.registration import EnvSpec
from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.utils import seeding
class JaxEnv(gym.Env):
    """A conversion layer exposing a jax-based ``FuncEnv`` through the stateful ``gym.Env`` API.

    The wrapper stores the current functional state and a jax PRNG key, splits
    the key on every ``reset`` and ``step``, and converts observations to numpy
    before returning them.
    """

    state: StateType
    rng: jrng.PRNGKey

    def __init__(
        self,
        func_env: FuncEnv,
        observation_space: Space,
        action_space: Space,
        metadata: Optional[Dict[str, Any]] = None,
        render_mode: Optional[str] = None,
        reward_range: Tuple[float, float] = (-float("inf"), float("inf")),
        spec: Optional[EnvSpec] = None,
    ):
        """Initialize the environment from a FuncEnv.

        Args:
            func_env: the functional environment to wrap.
            observation_space: observation space exposed by this env.
            action_space: action space exposed by this env.
            metadata: env metadata; an empty dict is used when ``None``.
            render_mode: when ``"rgb_array"``, the render state is created
                immediately via ``func_env.render_init()``.
            reward_range: stored as ``self.reward_range``.
            spec: stored as ``self.spec``.
        """
        if metadata is None:
            metadata = {}
        self.func_env = func_env

        self.observation_space = observation_space
        self.action_space = action_space

        self.metadata = metadata
        self.render_mode = render_mode
        self.reward_range = reward_range

        self.spec = spec

        self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)

        if self.render_mode == "rgb_array":
            self.render_state = self.func_env.render_init()
        else:
            self.render_state = None

        # Seed the jax key from a freshly created numpy generator.
        np_random, _ = seeding.np_random()
        seed = np_random.integers(0, 2**32 - 1, dtype="uint32")

        self.rng = jrng.PRNGKey(seed)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Sample a new initial state and return ``(observation, info)``.

        When ``seed`` is given, the jax key is re-created from it before the
        initial state is drawn. ``options`` is accepted but not used here.
        """
        super().reset(seed=seed)
        if seed is not None:
            self.rng = jrng.PRNGKey(seed)

        rng, self.rng = jrng.split(self.rng)

        self.state = self.func_env.initial(rng=rng)
        obs = self.func_env.observation(self.state)
        info = self.func_env.state_info(self.state)

        obs = _convert_jax_to_numpy(obs)

        return obs, info

    def step(self, action: ActType):
        """Apply ``action`` and return ``(observation, reward, terminated, truncated, info)``.

        Box actions are clipped to the space bounds; any other action must be
        contained in the action space. ``truncated`` is always ``False``.
        """
        if self._is_box_action_space:
            assert isinstance(self.action_space, gym.spaces.Box)  # For typing
            action = np.clip(action, self.action_space.low, self.action_space.high)
        else:  # Discrete
            # For now we assume jax envs don't use complex spaces
            err_msg = f"{action!r} ({type(action)}) invalid"
            assert self.action_space.contains(action), err_msg

        rng, self.rng = jrng.split(self.rng)

        next_state = self.func_env.transition(self.state, action, rng)
        # Observe the state the environment has moved to, not the one it just
        # left; otherwise the returned observation lags one step behind.
        observation = self.func_env.observation(next_state)
        reward = self.func_env.reward(self.state, action, next_state)
        terminated = self.func_env.terminal(next_state)
        info = self.func_env.step_info(self.state, action, next_state)
        self.state = next_state

        observation = _convert_jax_to_numpy(observation)

        return observation, float(reward), bool(terminated), False, info

    def render(self):
        """Return an image of the current state in ``"rgb_array"`` mode.

        Raises:
            NotImplementedError: for any other render mode.
        """
        if self.render_mode == "rgb_array":
            self.render_state, image = self.func_env.render_image(
                self.state, self.render_state
            )
            return image
        else:
            raise NotImplementedError

    def close(self):
        """Release the render state, if one was created."""
        if self.render_state is not None:
            self.func_env.render_close(self.render_state)
            self.render_state = None
def _convert_jax_to_numpy(element: Any):
"""
Convert a jax observation/action to a numpy array, or a numpy-based container.
Currently required because all tests assume that stuff is in numpy arrays, hopefully will be removed soon.
"""
if isinstance(element, jnp.ndarray):
return np.asarray(element)
elif isinstance(element, tuple):
return tuple(_convert_jax_to_numpy(e) for e in element)
elif isinstance(element, list):
return [_convert_jax_to_numpy(e) for e in element]
elif isinstance(element, dict):
return {k: _convert_jax_to_numpy(v) for k, v in element.items()}
else:
raise TypeError(f"Cannot convert {element} to numpy")

View File

@@ -0,0 +1,201 @@
"""
Implementation of a Jax-accelerated pendulum environment.
"""
from os import path
from typing import Optional, Tuple, Union
import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey
import gymnasium as gym
from gymnasium.envs.phys2d.conversion import JaxEnv
from gymnasium.error import DependencyNotInstalled
from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle
# Render state passed between render_init / render_image / render_close:
# the pygame surface, the pygame clock, and an optional float that
# render_image unpacks as ``last_u``.
RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]]  # type: ignore  # noqa: F821
class PendulumF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]):
    """Pendulum but in jax and functional.

    The state is the pair ``(theta, theta_dot)``; the observation is
    ``(cos(theta), sin(theta), theta_dot)``. Actions are length-1 torque arrays
    that are clipped to ``[-max_torque, max_torque]``.
    """

    max_speed = 8
    max_torque = 2.0
    dt = 0.05
    g = 10.0
    m = 1.0
    l = 1.0
    # Half-widths of the uniform ranges used by ``initial`` for theta and theta_dot.
    high_x = jnp.pi
    high_y = 1.0
    screen_dim = 500

    observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32)
    action_space = gym.spaces.Box(-max_torque, max_torque, shape=(1,), dtype=np.float32)

    def initial(self, rng: PRNGKey):
        """Sample ``(theta, theta_dot)`` uniformly from ``[-high_x, high_x] x [-high_y, high_y]``."""
        high = jnp.array([self.high_x, self.high_y])
        return jax.random.uniform(key=rng, minval=-high, maxval=high, shape=high.shape)

    def transition(
        self, state: jnp.ndarray, action: Union[int, jnp.ndarray], rng: None = None
    ) -> jnp.ndarray:
        """Advance the pendulum by one step of size ``dt``.

        The torque is clipped, the angular velocity is updated and clipped to
        ``[-max_speed, max_speed]``, and the angle is then advanced with the
        *new* angular velocity. ``rng`` is unused.
        """
        th, thdot = state  # th := theta
        u = action

        g = self.g
        m = self.m
        l = self.l
        dt = self.dt

        # The action is indexed, so it must be array-like with at least one element.
        u = jnp.clip(u, -self.max_torque, self.max_torque)[0]

        newthdot = thdot + (3 * g / (2 * l) * jnp.sin(th) + 3.0 / (m * l**2) * u) * dt
        newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)
        newth = th + newthdot * dt

        new_state = jnp.array([newth, newthdot])
        return new_state

    def observation(self, state: jnp.ndarray) -> jnp.ndarray:
        """Return ``(cos(theta), sin(theta), theta_dot)`` for the given state."""
        theta, thetadot = state
        return jnp.array([jnp.cos(theta), jnp.sin(theta), thetadot])

    def reward(self, state: StateType, action: ActType, next_state: StateType) -> float:
        """Return the negative cost of taking ``action`` in ``state``.

        The cost is the squared angle (wrapped into ``[-pi, pi)``) plus
        ``0.1 * theta_dot**2`` plus ``0.001 * u**2`` for the clipped torque
        ``u``. ``next_state`` is not used.
        """
        th, thdot = state  # th := theta
        u = action

        u = jnp.clip(u, -self.max_torque, self.max_torque)[0]

        th_normalized = ((th + jnp.pi) % (2 * jnp.pi)) - jnp.pi
        costs = th_normalized**2 + 0.1 * thdot**2 + 0.001 * (u**2)

        return -costs

    def terminal(self, state: StateType) -> bool:
        """Always ``False``: this environment has no terminal states."""
        return False

    def render_image(
        self,
        state: StateType,
        render_state: RenderStateType,
    ) -> Tuple[RenderStateType, np.ndarray]:
        """Draw ``state`` onto the render surface and return it as an image array.

        Args:
            state: the state to draw; only component 0 (the angle) is used.
            render_state: the ``(screen, clock, last_u)`` triple from ``render_init``.

        Returns:
            The (unchanged) render state and the screen pixels with the first
            two axes swapped relative to pygame's ``pixels3d`` layout.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
            from pygame import gfxdraw
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            ) from e
        screen, clock, last_u = render_state

        surf = pygame.Surface((self.screen_dim, self.screen_dim))
        surf.fill((255, 255, 255))

        bound = 2.2
        scale = self.screen_dim / (bound * 2)
        offset = self.screen_dim // 2

        rod_length = 1 * scale
        rod_width = 0.2 * scale
        l, r, t, b = 0, rod_length, rod_width / 2, -rod_width / 2
        coords = [(l, b), (l, t), (r, t), (r, b)]
        transformed_coords = []
        for c in coords:
            c = pygame.math.Vector2(c).rotate_rad(state[0] + np.pi / 2)
            c = (c[0] + offset, c[1] + offset)
            transformed_coords.append(c)
        gfxdraw.aapolygon(surf, transformed_coords, (204, 77, 77))
        gfxdraw.filled_polygon(surf, transformed_coords, (204, 77, 77))

        gfxdraw.aacircle(surf, offset, offset, int(rod_width / 2), (204, 77, 77))
        gfxdraw.filled_circle(surf, offset, offset, int(rod_width / 2), (204, 77, 77))

        rod_end = (rod_length, 0)
        rod_end = pygame.math.Vector2(rod_end).rotate_rad(state[0] + np.pi / 2)
        rod_end = (int(rod_end[0] + offset), int(rod_end[1] + offset))
        gfxdraw.aacircle(
            surf, rod_end[0], rod_end[1], int(rod_width / 2), (204, 77, 77)
        )
        gfxdraw.filled_circle(
            surf, rod_end[0], rod_end[1], int(rod_width / 2), (204, 77, 77)
        )

        if last_u is not None:
            # Only touch the disk when the torque arrow is actually drawn.
            fname = path.join(path.dirname(__file__), "assets/clockwise.png")
            img = pygame.image.load(fname)
            scale_img = pygame.transform.smoothscale(
                img,
                (scale * np.abs(last_u) / 2, scale * np.abs(last_u) / 2),
            )
            is_flip = bool(last_u > 0)
            scale_img = pygame.transform.flip(scale_img, is_flip, True)
            surf.blit(
                scale_img,
                (
                    offset - scale_img.get_rect().centerx,
                    offset - scale_img.get_rect().centery,
                ),
            )

        # drawing axle
        gfxdraw.aacircle(surf, offset, offset, int(0.05 * scale), (0, 0, 0))
        gfxdraw.filled_circle(surf, offset, offset, int(0.05 * scale), (0, 0, 0))

        surf = pygame.transform.flip(surf, False, True)
        screen.blit(surf, (0, 0))

        return (screen, clock, last_u), np.transpose(
            np.array(pygame.surfarray.pixels3d(screen)), axes=(1, 0, 2)
        )

    def render_init(
        self, screen_width: int = 600, screen_height: int = 400
    ) -> RenderStateType:
        """Initialise pygame and create the ``(screen, clock, None)`` render state.

        NOTE(review): the default 600x400 screen differs from the
        ``screen_dim`` x ``screen_dim`` surface drawn by ``render_image`` --
        confirm this is intended.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            ) from e

        pygame.init()
        screen = pygame.Surface((screen_width, screen_height))
        clock = pygame.time.Clock()

        return screen, clock, None

    def render_close(self, render_state: RenderStateType) -> None:
        """Shut down the pygame display module and pygame itself.

        Raises:
            DependencyNotInstalled: if pygame cannot be imported.
        """
        try:
            import pygame
        except ImportError as e:
            raise DependencyNotInstalled(
                "pygame is not installed, run `pip install gymnasium[classic_control]`"
            ) from e
        pygame.display.quit()
        pygame.quit()
class PendulumJaxEnv(JaxEnv, EzPickle):
    """Stateful environment built from a jit-transformed ``PendulumF``."""

    metadata = {"render_modes": ["rgb_array"], "render_fps": 30}

    def __init__(self, render_mode: Optional[str] = None, **kwargs):
        """Build the functional pendulum, jit it, and hand it to ``JaxEnv``.

        Args:
            render_mode: forwarded to ``JaxEnv``.
            **kwargs: forwarded verbatim to the ``PendulumF`` constructor and
                recorded by ``EzPickle``.
        """
        EzPickle.__init__(self, render_mode=render_mode, **kwargs)

        func_env = PendulumF(**kwargs)
        func_env.transform(jax.jit)

        super().__init__(
            func_env,
            observation_space=func_env.observation_space,
            action_space=func_env.action_space,
            metadata={"render_modes": ["rgb_array"], "render_fps": 30},
            render_mode=render_mode,
        )

96
gymnasium/functional.py Normal file
View File

@@ -0,0 +1,96 @@
"""Base class and definitions for an alternative, functional backend for gym envs, particularly suitable for hardware accelerated and otherwise transformed environments."""
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar
import numpy as np
# Type variables that parameterise FuncEnv, declared in the same order in
# which they appear in its Generic[...] base.
StateType = TypeVar("StateType")
ActType = TypeVar("ActType")
ObsType = TypeVar("ObsType")
RewardType = TypeVar("RewardType")
TerminalType = TypeVar("TerminalType")
RenderStateType = TypeVar("RenderStateType")
class FuncEnv(
    Generic[StateType, ObsType, ActType, RewardType, TerminalType, RenderStateType]
):
    """Template for environments written as a set of pure functions.

    The intended usage is stateless: the caller owns the environment state and
    passes it explicitly to every call. Stateful use is possible but not
    recommended.

    Subclasses provide the following instance methods:

    - initial: the initial state of the POMDP
    - observation: the observation for a given state
    - transition: the next state after taking an action in a given state
    - reward: the reward for a (state, action, next_state) tuple
    - terminal: whether a given state is terminal
    - state_info: optional, an info dict about a given state
    - step_info: optional, an info dict about a (state, action, next_state) tuple

    Using a class lets environment constants live as attributes that the
    methods can refer to by name.

    For the moment this is predominantly for internal use. The API is likely to
    change; the intention is to flesh it out and officially expose it later.
    """

    def __init__(self, options: Optional[Dict[str, Any]] = None):
        """Store every entry of ``options`` as an instance attribute."""
        if options:
            self.__dict__.update(options)

    def initial(self, rng: Any) -> StateType:
        """Produce the initial state; must be overridden."""
        raise NotImplementedError

    def observation(self, state: StateType) -> ObsType:
        """Produce the observation for ``state``; must be overridden."""
        raise NotImplementedError

    def transition(self, state: StateType, action: ActType, rng: Any) -> StateType:
        """Produce the successor of ``state`` under ``action``; must be overridden."""
        raise NotImplementedError

    def reward(
        self, state: StateType, action: ActType, next_state: StateType
    ) -> RewardType:
        """Produce the reward for one transition; must be overridden."""
        raise NotImplementedError

    def terminal(self, state: StateType) -> TerminalType:
        """Tell whether ``state`` is terminal; must be overridden."""
        raise NotImplementedError

    def state_info(self, state: StateType) -> dict:
        """Extra information about a single state; empty by default."""
        return dict()

    def step_info(
        self, state: StateType, action: ActType, next_state: StateType
    ) -> dict:
        """Extra information about a full transition; empty by default."""
        return dict()

    def transform(self, func: Callable[[Callable], Callable]):
        """Replace each core method on this instance with ``func`` applied to it.

        The wrapped methods are, in order: initial, transition, observation,
        reward, terminal, state_info and step_info. Render methods are left
        untouched.
        """
        for method_name in (
            "initial",
            "transition",
            "observation",
            "reward",
            "terminal",
            "state_info",
            "step_info",
        ):
            setattr(self, method_name, func(getattr(self, method_name)))

    def render_image(
        self, state: StateType, render_state: RenderStateType
    ) -> Tuple[RenderStateType, np.ndarray]:
        """Draw ``state`` and return the new render state with the image; must be overridden."""
        raise NotImplementedError

    def render_init(self, **kwargs) -> RenderStateType:
        """Create the render state; must be overridden."""
        raise NotImplementedError

    def render_close(self, render_state: RenderStateType):
        """Dispose of the render state; must be overridden."""
        raise NotImplementedError

View File

@@ -12,4 +12,4 @@ pygame==2.1.0
ale-py~=0.8.0
mujoco==2.2
mujoco_py<2.2,>=2.1
imageio>=2.14.1
imageio>=2.14.1

View File

@@ -40,6 +40,7 @@ extras = {
"mujoco_py": ["mujoco_py<2.2,>=2.1"],
"mujoco": ["mujoco==2.2", "imageio>=2.14.1"],
"toy_text": ["pygame==2.1.0"],
"jax": ["jax==0.3.20", "jaxlib==0.3.20"],
"other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0", "moviepy>=1.0.0"],
}

View File

View File

@@ -0,0 +1,56 @@
from typing import Any, Dict, Optional
import numpy as np
from gymnasium.functional import FuncEnv
class TestEnv(FuncEnv):
    """Minimal numpy-backed functional env used to exercise the FuncEnv API.

    The state is a float32 pair; an action is added to the second component,
    and a state is terminal (and rewarded) once that component is positive.
    """

    # This is a helper, not a test case. Its name matches pytest's ``Test*``
    # class pattern and it defines ``__init__``, so opt out of collection
    # explicitly to avoid a collection warning.
    __test__ = False

    def __init__(self, options: Optional[Dict[str, Any]] = None):
        """Forward ``options`` to ``FuncEnv``."""
        super().__init__(options)

    def initial(self, rng: Any) -> np.ndarray:
        """Return the fixed initial state ``[0, 0]``; ``rng`` is ignored."""
        return np.array([0, 0], dtype=np.float32)

    def observation(self, state: np.ndarray) -> np.ndarray:
        """Return the state itself."""
        return state

    def transition(self, state: np.ndarray, action: int, rng: None) -> np.ndarray:
        """Add ``action`` to the second state component; ``rng`` is ignored."""
        return state + np.array([0, action], dtype=np.float32)

    def reward(self, state: np.ndarray, action: int, next_state: np.ndarray) -> float:
        """Return 1.0 when the second component of ``next_state`` is positive, else 0.0."""
        return 1.0 if next_state[1] > 0 else 0.0

    def terminal(self, state: np.ndarray) -> bool:
        """Return whether the second state component is positive."""
        return state[1] > 0
def test_api():
    """Walk ``TestEnv`` through a fixed action sequence and check every API output."""
    env = TestEnv()

    state = env.initial(None)
    obs = env.observation(state)

    for array in (state, obs):
        assert array.shape == (2,)
        assert array.dtype == np.float32
    assert np.allclose(obs, state)

    actions = [-1, -2, -5, 3, 5, 2]
    final_index = len(actions) - 1
    for step_index, action in enumerate(actions):
        next_state = env.transition(state, action, None)
        assert next_state.shape == (2,)
        assert next_state.dtype == np.float32
        assert np.allclose(next_state, state + np.array([0, action]))

        observation = env.observation(next_state)
        assert observation.shape == (2,)
        assert observation.dtype == np.float32
        assert np.allclose(observation, next_state)

        expected_reward = 1.0 if next_state[1] > 0 else 0.0
        reward = env.reward(state, action, next_state)
        assert reward == expected_reward

        # Only the final action pushes the second component above zero.
        terminal = env.terminal(next_state)
        assert terminal == (step_index == final_index)

        state = next_state

View File

@@ -0,0 +1,105 @@
import jax
import jax.numpy as jnp
import jax.random as jrng
import numpy as np
import pytest
from gymnasium.envs.phys2d.cartpole import CartPoleF # noqa: E402
from gymnasium.envs.phys2d.pendulum import PendulumF # noqa: E402
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF])
def test_normal(env_class):
    """Roll an untransformed functional env for ten steps and check output types."""
    env = env_class()
    rng = jrng.PRNGKey(0)

    state = env.initial(rng)
    env.action_space.seed(0)

    for _ in range(10):
        obs = env.observation(state)
        action = env.action_space.sample()
        next_state = env.transition(state, action, None)
        reward = env.reward(state, action, next_state)
        terminal = env.terminal(next_state)

        assert next_state.shape == state.shape

        # float() on a non-scalar array raises TypeError while bool() raises
        # ValueError, so catch both to report a clean failure either way.
        try:
            float(reward)
        except (TypeError, ValueError):
            pytest.fail("Reward is not castable to float")
        try:
            bool(terminal)
        except (TypeError, ValueError):
            pytest.fail("Terminal is not castable to bool")

        assert next_state.dtype == jnp.float32
        assert isinstance(obs, jnp.ndarray)
        assert obs.dtype == jnp.float32

        state = next_state
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF])
def test_jit(env_class):
    """Roll a ``jax.jit``-transformed functional env for ten steps and check output types."""
    env = env_class()
    rng = jrng.PRNGKey(0)

    env.transform(jax.jit)
    state = env.initial(rng)
    env.action_space.seed(0)

    for _ in range(10):
        obs = env.observation(state)
        action = env.action_space.sample()
        next_state = env.transition(state, action, None)
        reward = env.reward(state, action, next_state)
        terminal = env.terminal(next_state)

        assert next_state.shape == state.shape

        # float() on a non-scalar array raises TypeError while bool() raises
        # ValueError, so catch both to report a clean failure either way.
        try:
            float(reward)
        except (TypeError, ValueError):
            pytest.fail("Reward is not castable to float")
        try:
            bool(terminal)
        except (TypeError, ValueError):
            pytest.fail("Terminal is not castable to bool")

        assert next_state.dtype == jnp.float32
        assert isinstance(obs, jnp.ndarray)
        assert obs.dtype == jnp.float32

        state = next_state
@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF])
def test_vmap(env_class):
    """Roll a vmapped-then-jitted functional env and check batched shapes and dtypes."""
    env = env_class()
    num_envs = 10
    rng = jrng.split(jrng.PRNGKey(0), num_envs)

    env.transform(jax.vmap)
    env.transform(jax.jit)
    state = env.initial(rng)
    env.action_space.seed(0)

    for _ in range(10):
        obs = env.observation(state)
        action = jnp.array([env.action_space.sample() for _ in range(num_envs)])
        next_state = env.transition(state, action, None)
        terminal = env.terminal(next_state)
        reward = env.reward(state, action, next_state)

        assert next_state.shape == state.shape
        assert next_state.dtype == jnp.float32
        assert reward.shape == (num_envs,)
        assert reward.dtype == jnp.float32
        assert terminal.shape == (num_envs,)
        # Compare against the scalar bool type; the ``np.bool`` alias is
        # deprecated and missing from some NumPy releases.
        assert terminal.dtype == jnp.bool_
        assert isinstance(obs, jnp.ndarray)
        assert obs.dtype == jnp.float32

        state = next_state

View File

@@ -171,7 +171,6 @@ def test_render_modes(spec):
env = spec.make()
assert "rgb_array" in env.metadata["render_modes"]
assert "human" in env.metadata["render_modes"]
for mode in env.metadata["render_modes"]:
if mode != "human":

View File

@@ -18,173 +18,3 @@ CartPole-v1
"""
assert out == correct_out
def test_pprint_registry():
    """Test the printed listing of the default registry, with no changes."""
    # disable_print=True makes pprint_registry return the listing as a string.
    out = gym.pprint_registry(disable_print=True)
    # Expected listing: one "===== namespace =====" header per group followed
    # by that group's environment ids.
    # NOTE(review): column padding and blank lines inside this literal may have
    # been collapsed in this view -- confirm against the repository.
    correct_out = """===== classic_control =====
Acrobot-v1
CartPole-v0
CartPole-v1
MountainCar-v0
MountainCarContinuous-v0
Pendulum-v1
===== box2d =====
BipedalWalker-v3
BipedalWalkerHardcore-v3
CarRacing-v2
LunarLander-v2
LunarLanderContinuous-v2
===== toy_text =====
Blackjack-v1
CliffWalking-v0
FrozenLake-v1
FrozenLake8x8-v1
Taxi-v3
===== mujoco =====
Ant-v2 Ant-v3 Ant-v4
HalfCheetah-v2 HalfCheetah-v3 HalfCheetah-v4
Hopper-v2 Hopper-v3 Hopper-v4
Humanoid-v2 Humanoid-v3 Humanoid-v4
HumanoidStandup-v2 HumanoidStandup-v4 InvertedDoublePendulum-v2
InvertedDoublePendulum-v4 InvertedPendulum-v2 InvertedPendulum-v4
Pusher-v2 Pusher-v4 Reacher-v2
Reacher-v4 Swimmer-v2 Swimmer-v3
Swimmer-v4 Walker2d-v2 Walker2d-v3
Walker2d-v4
===== external =====
GymV26Environment-v0
===== utils_envs =====
RegisterDuringMakeEnv-v0
test.ArgumentEnv-v0
test.OrderlessArgumentEnv-v0
===== test =====
test/NoHuman-v0
test/NoHumanNoRGB-v0
test/NoHumanOldAPI-v0
"""
    assert out == correct_out
def test_pprint_registry_exclude_namespaces():
    """Test the printed registry listing with the classic_control namespace excluded and max_rows=20."""
    out = gym.pprint_registry(
        max_rows=20, exclude_namespaces=["classic_control"], disable_print=True
    )
    # Expected listing: no classic_control group is present.
    # NOTE(review): column padding and blank lines inside this literal may have
    # been collapsed in this view -- confirm against the repository.
    correct_out = """===== box2d =====
BipedalWalker-v3
BipedalWalkerHardcore-v3
CarRacing-v2
LunarLander-v2
LunarLanderContinuous-v2
===== toy_text =====
Blackjack-v1
CliffWalking-v0
FrozenLake-v1
FrozenLake8x8-v1
Taxi-v3
===== mujoco =====
Ant-v2 Ant-v3
Ant-v4 HalfCheetah-v2
HalfCheetah-v3 HalfCheetah-v4
Hopper-v2 Hopper-v3
Hopper-v4 Humanoid-v2
Humanoid-v3 Humanoid-v4
HumanoidStandup-v2 HumanoidStandup-v4
InvertedDoublePendulum-v2 InvertedDoublePendulum-v4
InvertedPendulum-v2 InvertedPendulum-v4
Pusher-v2 Pusher-v4
Reacher-v2 Reacher-v4
Swimmer-v2 Swimmer-v3
Swimmer-v4 Walker2d-v2
Walker2d-v3 Walker2d-v4
===== external =====
GymV26Environment-v0
===== utils_envs =====
RegisterDuringMakeEnv-v0
test.ArgumentEnv-v0
test.OrderlessArgumentEnv-v0
===== test =====
test/NoHuman-v0
test/NoHumanNoRGB-v0
test/NoHumanOldAPI-v0
"""
    assert out == correct_out
def test_pprint_registry_no_entry_point():
    """Test the registry listing when an environment with no entry point is registered."""
    # Temporarily register an extra environment; it is removed again at the end.
    gym.register("NoNamespaceEnv", "no-entry-point")
    out = gym.pprint_registry(disable_print=True)
    # Expected listing: the extra environment appears under its own
    # "NoNamespaceEnv" header at the end.
    # NOTE(review): column padding and blank lines inside this literal may have
    # been collapsed in this view -- confirm against the repository.
    correct_out = """===== classic_control =====
Acrobot-v1
CartPole-v0
CartPole-v1
MountainCar-v0
MountainCarContinuous-v0
Pendulum-v1
===== box2d =====
BipedalWalker-v3
BipedalWalkerHardcore-v3
CarRacing-v2
LunarLander-v2
LunarLanderContinuous-v2
===== toy_text =====
Blackjack-v1
CliffWalking-v0
FrozenLake-v1
FrozenLake8x8-v1
Taxi-v3
===== mujoco =====
Ant-v2 Ant-v3 Ant-v4
HalfCheetah-v2 HalfCheetah-v3 HalfCheetah-v4
Hopper-v2 Hopper-v3 Hopper-v4
Humanoid-v2 Humanoid-v3 Humanoid-v4
HumanoidStandup-v2 HumanoidStandup-v4 InvertedDoublePendulum-v2
InvertedDoublePendulum-v4 InvertedPendulum-v2 InvertedPendulum-v4
Pusher-v2 Pusher-v4 Reacher-v2
Reacher-v4 Swimmer-v2 Swimmer-v3
Swimmer-v4 Walker2d-v2 Walker2d-v3
Walker2d-v4
===== external =====
GymV26Environment-v0
===== utils_envs =====
RegisterDuringMakeEnv-v0
test.ArgumentEnv-v0
test.OrderlessArgumentEnv-v0
===== test =====
test/NoHuman-v0
test/NoHumanNoRGB-v0
test/NoHumanOldAPI-v0
===== NoNamespaceEnv =====
NoNamespaceEnv
"""
    assert out == correct_out
    # Clean up so the extra registration does not leak into other tests.
    del gym.envs.registry["NoNamespaceEnv"]

View File

@@ -2,7 +2,7 @@ import pytest
import gymnasium as gym
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
from gymnasium.wrappers import OrderEnforcing, TimeLimit, TransformObservation
from gymnasium.wrappers import TimeLimit, TransformObservation
from gymnasium.wrappers.env_checker import PassiveEnvChecker
from tests.wrappers.utils import has_wrapper
@@ -39,8 +39,6 @@ def test_vector_make_wrappers():
sub_env = env.envs[0]
assert isinstance(sub_env, gym.Env)
assert sub_env.spec is not None
if sub_env.spec.order_enforce:
assert has_wrapper(sub_env, OrderEnforcing)
if sub_env.spec.max_episode_steps is not None:
assert has_wrapper(sub_env, TimeLimit)