mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-15 11:18:47 +00:00
Jax environments return jax data rather than numpy data (#817)
Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
committed by
GitHub
parent
f0202ae350
commit
d43037920f
@@ -60,6 +60,7 @@ register(
|
||||
vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv",
|
||||
max_episode_steps=200,
|
||||
reward_threshold=195.0,
|
||||
disable_env_checker=True,
|
||||
)
|
||||
|
||||
register(
|
||||
@@ -68,6 +69,7 @@ register(
|
||||
vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv",
|
||||
max_episode_steps=500,
|
||||
reward_threshold=475.0,
|
||||
disable_env_checker=True,
|
||||
)
|
||||
|
||||
register(
|
||||
@@ -75,6 +77,7 @@ register(
|
||||
entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxEnv",
|
||||
vector_entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxVectorEnv",
|
||||
max_episode_steps=200,
|
||||
disable_env_checker=True,
|
||||
)
|
||||
|
||||
# Box2d
|
||||
@@ -161,11 +164,13 @@ register(
|
||||
register(
|
||||
id="tabular/Blackjack-v0",
|
||||
entry_point="gymnasium.envs.tabular.blackjack:BlackJackJaxEnv",
|
||||
disable_env_checker=True,
|
||||
)
|
||||
|
||||
register(
|
||||
id="tabular/CliffWalking-v0",
|
||||
entry_point="gymnasium.envs.tabular.cliffwalking:CliffWalkingJaxEnv",
|
||||
disable_env_checker=True,
|
||||
)
|
||||
|
||||
|
||||
|
@@ -6,14 +6,12 @@ from typing import Any
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.random as jrng
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
from gymnasium.functional import ActType, FuncEnv, StateType
|
||||
from gymnasium.utils import seeding
|
||||
from gymnasium.vector.utils import batch_space
|
||||
from gymnasium.wrappers.jax_to_numpy import jax_to_numpy
|
||||
|
||||
|
||||
class FunctionalJaxEnv(gym.Env):
|
||||
@@ -32,7 +30,8 @@ class FunctionalJaxEnv(gym.Env):
|
||||
):
|
||||
"""Initialize the environment from a FuncEnv."""
|
||||
if metadata is None:
|
||||
metadata = {"render_mode": []}
|
||||
# metadata.get("jax", False) can be used downstream to know that the environment returns jax arrays
|
||||
metadata = {"render_mode": [], "jax": True}
|
||||
|
||||
self.func_env = func_env
|
||||
|
||||
@@ -45,8 +44,6 @@ class FunctionalJaxEnv(gym.Env):
|
||||
|
||||
self.spec = spec
|
||||
|
||||
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
|
||||
|
||||
if self.render_mode == "rgb_array":
|
||||
self.render_state = self.func_env.render_init()
|
||||
else:
|
||||
@@ -69,20 +66,10 @@ class FunctionalJaxEnv(gym.Env):
|
||||
obs = self.func_env.observation(self.state)
|
||||
info = self.func_env.state_info(self.state)
|
||||
|
||||
obs = jax_to_numpy(obs)
|
||||
|
||||
return obs, info
|
||||
|
||||
def step(self, action: ActType):
|
||||
"""Steps through the environment using the action."""
|
||||
if self._is_box_action_space:
|
||||
assert isinstance(self.action_space, gym.spaces.Box) # For typing
|
||||
action = np.clip(action, self.action_space.low, self.action_space.high)
|
||||
else: # Discrete
|
||||
# For now we assume jax envs don't use complex spaces
|
||||
err_msg = f"{action!r} ({type(action)}) invalid"
|
||||
assert self.action_space.contains(action), err_msg
|
||||
|
||||
rng, self.rng = jrng.split(self.rng)
|
||||
|
||||
next_state = self.func_env.transition(self.state, action, rng)
|
||||
@@ -92,8 +79,6 @@ class FunctionalJaxEnv(gym.Env):
|
||||
info = self.func_env.transition_info(self.state, action, next_state)
|
||||
self.state = next_state
|
||||
|
||||
observation = jax_to_numpy(observation)
|
||||
|
||||
return observation, float(reward), bool(terminated), False, info
|
||||
|
||||
def render(self):
|
||||
@@ -153,8 +138,6 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
|
||||
|
||||
self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_)
|
||||
|
||||
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
|
||||
|
||||
if self.render_mode == "rgb_array":
|
||||
self.render_state = self.func_env.render_init()
|
||||
else:
|
||||
@@ -183,20 +166,10 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
|
||||
|
||||
self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)
|
||||
|
||||
obs = jax_to_numpy(obs)
|
||||
|
||||
return obs, info
|
||||
|
||||
def step(self, action: ActType):
|
||||
"""Steps through the environment using the action."""
|
||||
if self._is_box_action_space:
|
||||
assert isinstance(self.action_space, gym.spaces.Box) # For typing
|
||||
action = np.clip(action, self.action_space.low, self.action_space.high)
|
||||
else: # Discrete
|
||||
# For now we assume jax envs don't use complex spaces
|
||||
assert self.action_space.contains(
|
||||
action
|
||||
), f"{action!r} ({type(action)}) invalid"
|
||||
self.steps += 1
|
||||
|
||||
rng, self.rng = jrng.split(self.rng)
|
||||
@@ -232,12 +205,6 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
|
||||
self.autoreset_envs = done
|
||||
|
||||
observation = self.func_env.observation(next_state)
|
||||
observation = jax_to_numpy(observation)
|
||||
|
||||
reward = jax_to_numpy(reward)
|
||||
|
||||
terminated = jax_to_numpy(terminated)
|
||||
truncated = jax_to_numpy(truncated)
|
||||
|
||||
self.state = next_state
|
||||
|
||||
|
@@ -245,7 +245,7 @@ class CartPoleFunctional(
|
||||
class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
|
||||
"""Jax-based implementation of the CartPole environment."""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
|
||||
|
||||
def __init__(self, render_mode: str | None = None, **kwargs: Any):
|
||||
"""Constructor for the CartPole where the kwargs are applied to the functional environment."""
|
||||
@@ -265,7 +265,7 @@ class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
|
||||
class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
|
||||
"""Jax-based implementation of the vectorized CartPole environment."""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@@ -223,7 +223,7 @@ class PendulumFunctional(
|
||||
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
|
||||
"""Jax-based pendulum environment using the functional version as base."""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 30, "jax": True}
|
||||
|
||||
def __init__(self, render_mode: str | None = None, **kwargs: Any):
|
||||
"""Constructor where the kwargs are passed to the base environment to modify the parameters."""
|
||||
@@ -242,7 +242,7 @@ class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
|
||||
class PendulumJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
|
||||
"""Jax-based implementation of the vectorized Pendulum environment."""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@@ -496,7 +496,7 @@ class BlackjackFunctional(
|
||||
class BlackJackJaxEnv(FunctionalJaxEnv, EzPickle):
|
||||
"""A Gymnasium Env wrapper for the functional blackjack env."""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
|
||||
|
||||
def __init__(self, render_mode: Optional[str] = None, **kwargs):
|
||||
"""Initializes Gym wrapper for blackjack functional env."""
|
||||
|
@@ -358,7 +358,7 @@ class CliffWalkingFunctional(
|
||||
class CliffWalkingJaxEnv(FunctionalJaxEnv, EzPickle):
|
||||
"""A Gymnasium Env wrapper for the functional cliffwalking env."""
|
||||
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
|
||||
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
|
||||
|
||||
def __init__(self, render_mode: str | None = None, **kwargs):
|
||||
"""Initializes Gym wrapper for cliffwalking functional env."""
|
||||
|
@@ -367,6 +367,11 @@ def check_env(
|
||||
f"The environment ({env}) is different from the unwrapped version ({env.unwrapped}). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`."
|
||||
)
|
||||
|
||||
if env.metadata.get("jax", False):
|
||||
env = gym.wrappers.JaxToNumpy(env)
|
||||
elif env.metadata.get("torch", False):
|
||||
env = gym.wrappers.JaxToTorch(env)
|
||||
|
||||
# ============= Check the spaces (observation and action) ================
|
||||
if not hasattr(env, "action_space"):
|
||||
raise AttributeError(
|
||||
|
@@ -120,7 +120,7 @@ def test_vectorized(env_class):
|
||||
|
||||
obs, info = env.reset(seed=0)
|
||||
assert obs.shape == (10,) + env.single_observation_space.shape
|
||||
assert isinstance(obs, np.ndarray)
|
||||
assert isinstance(obs, jax.Array)
|
||||
assert isinstance(info, dict)
|
||||
|
||||
for t in range(100):
|
||||
@@ -128,13 +128,13 @@ def test_vectorized(env_class):
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
|
||||
assert obs.shape == (10,) + env.single_observation_space.shape
|
||||
assert isinstance(obs, np.ndarray)
|
||||
assert isinstance(obs, jax.Array)
|
||||
assert reward.shape == (10,)
|
||||
assert isinstance(reward, np.ndarray)
|
||||
assert isinstance(reward, jax.Array)
|
||||
assert terminated.shape == (10,)
|
||||
assert isinstance(terminated, np.ndarray)
|
||||
assert isinstance(terminated, jax.Array)
|
||||
assert truncated.shape == (10,)
|
||||
assert isinstance(truncated, np.ndarray)
|
||||
assert isinstance(truncated, jax.Array)
|
||||
assert isinstance(info, dict)
|
||||
|
||||
# These were removed in the new autoreset order
|
||||
|
@@ -70,6 +70,9 @@ def test_discrete_actions_out_of_bound(env: gym.Env):
|
||||
Args:
|
||||
env (gym.Env): the gymnasium environment
|
||||
"""
|
||||
if env.metadata.get("jax", False):
|
||||
return
|
||||
|
||||
assert isinstance(env.action_space, spaces.Discrete)
|
||||
upper_bound = env.action_space.start + env.action_space.n - 1
|
||||
|
||||
@@ -102,6 +105,9 @@ def test_box_actions_out_of_bound(env: gym.Env):
|
||||
Args:
|
||||
env (gym.Env): the gymnasium environment
|
||||
"""
|
||||
if env.metadata.get("jax", False):
|
||||
return
|
||||
|
||||
env.reset(seed=42)
|
||||
|
||||
assert env.spec is not None
|
||||
|
@@ -7,11 +7,7 @@ import pytest
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
from gymnasium.utils.env_checker import check_env, data_equivalence
|
||||
from tests.envs.utils import (
|
||||
all_testing_env_specs,
|
||||
all_testing_initialised_envs,
|
||||
assert_equals,
|
||||
)
|
||||
from tests.envs.utils import all_testing_env_specs, all_testing_initialised_envs
|
||||
|
||||
|
||||
# This runs a smoketest on each official registered env. We may want
|
||||
@@ -42,6 +38,7 @@ def test_all_env_api(spec):
|
||||
"""Check that all environments pass the environment checker with no warnings other than the expected."""
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
env = spec.make().unwrapped
|
||||
|
||||
check_env(env, skip_render_check=True)
|
||||
|
||||
env.close()
|
||||
@@ -98,9 +95,13 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
|
||||
env_1 = env_spec.make(disable_env_checker=True)
|
||||
env_2 = env_spec.make(disable_env_checker=True)
|
||||
|
||||
if env_1.metadata.get("jax", False):
|
||||
env_1 = gym.wrappers.JaxToNumpy(env_1)
|
||||
env_2 = gym.wrappers.JaxToNumpy(env_2)
|
||||
|
||||
initial_obs_1, initial_info_1 = env_1.reset(seed=SEED)
|
||||
initial_obs_2, initial_info_2 = env_2.reset(seed=SEED)
|
||||
assert_equals(initial_obs_1, initial_obs_2)
|
||||
assert data_equivalence(initial_obs_1, initial_obs_2, exact=True)
|
||||
|
||||
env_1.action_space.seed(SEED)
|
||||
|
||||
@@ -111,7 +112,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
|
||||
obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
|
||||
obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)
|
||||
|
||||
assert_equals(obs_1, obs_2, f"[{time_step}] ")
|
||||
assert data_equivalence(
|
||||
obs_1, obs_2, exact=True
|
||||
), f"[{time_step}] obs_1={obs_1}, obs_2={obs_2}"
|
||||
assert env_1.observation_space.contains(
|
||||
obs_1
|
||||
) # obs_2 verified by previous assertion
|
||||
@@ -123,7 +126,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
|
||||
assert (
|
||||
truncated_1 == truncated_2
|
||||
), f"[{time_step}] done 1={truncated_1}, done 2={truncated_2}"
|
||||
assert_equals(info_1, info_2, f"[{time_step}] ")
|
||||
assert data_equivalence(
|
||||
info_1, info_2, exact=True
|
||||
), f"[{time_step}] info_1={info_1}, info_2={info_2}"
|
||||
|
||||
if (
|
||||
terminated_1 or truncated_1
|
||||
@@ -141,6 +146,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
|
||||
ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
|
||||
)
|
||||
def test_pickle_env(env: gym.Env):
|
||||
if env.metadata.get("jax", False):
|
||||
env = gym.wrappers.JaxToNumpy(env)
|
||||
|
||||
pickled_env = pickle.loads(pickle.dumps(env))
|
||||
|
||||
data_equivalence(env.reset(), pickled_env.reset())
|
||||
|
@@ -1,8 +1,6 @@
|
||||
"""Finds all the specs that we can test with"""
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import logger
|
||||
from gymnasium.envs.registration import EnvSpec
|
||||
@@ -55,28 +53,3 @@ gym_testing_env_specs: List[EnvSpec] = [
|
||||
for ep in ["box2d", "classic_control", "toy_text"]
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def assert_equals(a, b, prefix=None):
    """Assert equality of data structures `a` and `b`.

    Dicts, tuples and numpy arrays are compared structurally (recursively for
    dicts and tuples); every other value is compared with ``==``.

    Args:
        a: first data structure
        b: second data structure
        prefix: prefix for failed assertion messages; ``None`` means no prefix

    Raises:
        AssertionError: if the two structures differ in type, keys, length or value
    """
    # Normalise once so that a missing prefix never renders as the text "None".
    prefix = "" if prefix is None else prefix

    assert type(a) is type(b), f"{prefix}Differing types: {a} and {b}"
    if isinstance(a, dict):
        # The comparison is order-sensitive: both dicts must list the same keys
        # in the same insertion order.
        assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"

        for key in a:
            # Forward the prefix so nested failures keep their context.
            assert_equals(a[key], b[key], prefix)
    elif isinstance(a, np.ndarray):
        np.testing.assert_array_equal(a, b, err_msg=prefix)
    elif isinstance(a, tuple):
        # zip() stops at the shorter tuple, so the lengths are checked explicitly.
        assert len(a) == len(b), f"{prefix}Tuple lengths differ: {a} and {b}"
        for elem_from_a, elem_from_b in zip(a, b):
            assert_equals(elem_from_a, elem_from_b, prefix)
    else:
        assert a == b, f"{prefix}Differing values: {a} and {b}"
|
||||
|
@@ -19,6 +19,9 @@ from tests.testing_env import GenericTestEnv
|
||||
ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
|
||||
)
|
||||
def test_passive_checker_wrapper_warnings(env):
|
||||
if env.spec is not None and env.spec.disable_env_checker:
|
||||
return
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
checker_env = PassiveEnvChecker(env)
|
||||
checker_env.reset()
|
||||
|
Reference in New Issue
Block a user