Files
Gymnasium/tests/functional/test_func_jax_env.py

148 lines
4.4 KiB
Python

"""Test the functional jax environment."""
import numpy as np
import pytest
# Unconditionally skip this whole module. `allow_module_level=True` lets
# `pytest.skip` be called outside a test function, so the skip is triggered
# while the module is being imported/collected.
pytest.skip(
"Github CI is running forever for the tests in this file.", allow_module_level=True
)
# Import jax through pytest so that a missing jax installation results in the
# module being skipped instead of an ImportError.
jax = pytest.importorskip("jax")
import jax.numpy as jnp # noqa: E402
import jax.random as jrng # noqa: E402
import gymnasium as gym # noqa: E402
from gymnasium.envs.phys2d.cartpole import CartPoleFunctional # noqa: E402
from gymnasium.envs.phys2d.pendulum import PendulumFunctional # noqa: E402
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
def test_without_transform(env_class):
    """Tests the environment without transforming the environment.

    Runs ten transitions with actions sampled from the seeded action space and
    checks on every step that:

    * the next state keeps the shape of the current state and is ``float32``,
    * the observation is a ``float32`` ``jax.Array``,
    * the reward can be converted with ``float`` and the terminal flag with
      ``bool``.

    Args:
        env_class: The functional environment class to instantiate and test.
    """
    env = env_class()
    # The same PRNG key is passed to every functional call below.
    rng = jrng.PRNGKey(0)

    state = env.initial(rng)
    env.action_space.seed(0)

    for _ in range(10):
        obs = env.observation(state, rng)
        action = env.action_space.sample()
        next_state = env.transition(state, action, rng)
        reward = env.reward(state, action, next_state, rng)
        terminal = env.terminal(next_state, rng)

        assert next_state.shape == state.shape

        # Converting a non-scalar array to a Python scalar is typically
        # reported as TypeError rather than ValueError, so both are caught to
        # make sure the descriptive failure message is the one that is shown.
        try:
            float(reward)
        except (TypeError, ValueError):
            pytest.fail("Reward is not castable to float")
        try:
            bool(terminal)
        except (TypeError, ValueError):
            pytest.fail("Terminal is not castable to bool")

        assert next_state.dtype == jnp.float32
        assert isinstance(obs, jax.Array)
        assert obs.dtype == jnp.float32

        state = next_state
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
def test_jit(env_class):
    """Tests jitting the functional instance functions.

    Calls ``env.transform(jax.jit)`` on a fresh environment and then runs ten
    transitions with actions sampled from the seeded action space, checking on
    every step that:

    * the next state keeps the shape of the current state and is ``float32``,
    * the observation is a ``float32`` ``jax.Array``,
    * the reward can be converted with ``float`` and the terminal flag with
      ``bool``.

    Args:
        env_class: The functional environment class to instantiate and test.
    """
    env = env_class()
    # The same PRNG key is passed to every functional call below.
    rng = jrng.PRNGKey(0)

    env.transform(jax.jit)
    state = env.initial(rng)
    env.action_space.seed(0)

    for _ in range(10):
        obs = env.observation(state, rng)
        action = env.action_space.sample()
        next_state = env.transition(state, action, rng)
        reward = env.reward(state, action, next_state, rng)
        terminal = env.terminal(next_state, rng)

        assert next_state.shape == state.shape

        # Converting a non-scalar array to a Python scalar is typically
        # reported as TypeError rather than ValueError, so both are caught to
        # make sure the descriptive failure message is the one that is shown.
        try:
            float(reward)
        except (TypeError, ValueError):
            pytest.fail("Reward is not castable to float")
        try:
            bool(terminal)
        except (TypeError, ValueError):
            pytest.fail("Terminal is not castable to bool")

        assert next_state.dtype == jnp.float32
        assert isinstance(obs, jax.Array)
        assert obs.dtype == jnp.float32

        state = next_state
@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional])
def test_vmap(env_class):
    """Tests vmap of functional instance functions with transform.

    The environment is transformed with ``jax.vmap`` and then ``jax.jit``, and
    is driven with one PRNG key per sub-environment. Ten batched transitions
    are executed; after each one the shapes and dtypes of the next state,
    reward, terminal flag and observation are verified.

    Args:
        env_class: The functional environment class to instantiate and test.
    """
    env = env_class()
    batch_size = 10
    # One key per batched environment, all derived from a single seed key.
    keys = jrng.split(jrng.PRNGKey(0), batch_size)

    env.transform(jax.vmap)
    env.transform(jax.jit)

    states = env.initial(keys)
    env.action_space.seed(0)

    for _step in range(10):
        obs = env.observation(states, keys)
        # Sampled actions are stacked as-is; no per-env reshape is applied.
        sampled = [env.action_space.sample() for _ in range(batch_size)]
        action = jnp.array(sampled)

        next_states = env.transition(states, action, keys)
        terminal = env.terminal(next_states, keys)
        reward = env.reward(states, action, next_states, keys)

        assert next_states.shape == states.shape
        assert next_states.dtype == jnp.float32
        assert reward.shape == (batch_size,)
        assert reward.dtype == jnp.float32
        assert terminal.shape == (batch_size,)
        assert terminal.dtype == bool
        assert isinstance(obs, jax.Array)
        assert obs.dtype == jnp.float32

        states = next_states
@pytest.mark.parametrize("vectorization_mode", ["vector_entry_point", "sync", "async"])
def test_equal_episode_length(vectorization_mode: str):
    """Tests that the number of steps in an episode is the same.

    Builds a two-environment vectorized ``phys2d/Pendulum-v0`` with the given
    vectorization mode, steps it 1000 times with sampled actions, and asserts
    that an episode end (termination or truncation in any sub-environment) is
    reported exactly on the expected step indices and on no others.

    Args:
        vectorization_mode: Value forwarded to ``gym.make_vec`` as
            ``vectorization_mode``.
    """
    env = gym.make_vec("phys2d/Pendulum-v0", 2, vectorization_mode=vectorization_mode)
    # By default, the total number of steps per episode is 200
    expected_dones = {199, 399, 599, 799, 999}

    # Always close the vector env, even when an assertion fails, so that any
    # resources it holds (e.g. worker processes in "async" mode) are released.
    try:
        env.action_space.seed(0)
        env.reset()
        for t in range(1000):
            actions = env.action_space.sample()
            _, _, term, trunc, _ = env.step(actions)
            done = np.logical_or(term, trunc).any()
            if done:
                assert t in expected_dones
                # Take one extra step after an episode end without counting
                # it in `t` (presumably the step on which the vector env
                # resets the finished sub-environments - TODO confirm).
                env.step(actions)
            else:
                assert t not in expected_dones
    finally:
        env.close()