Jax environments return jax data rather than numpy data (#817)

Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
Ariel Kwiatkowski
2024-04-05 18:21:10 +02:00
committed by GitHub
parent f0202ae350
commit d43037920f
12 changed files with 48 additions and 81 deletions

View File

@@ -60,6 +60,7 @@ register(
vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv",
max_episode_steps=200,
reward_threshold=195.0,
disable_env_checker=True,
)
register(
@@ -68,6 +69,7 @@ register(
vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv",
max_episode_steps=500,
reward_threshold=475.0,
disable_env_checker=True,
)
register(
@@ -75,6 +77,7 @@ register(
entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxEnv",
vector_entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxVectorEnv",
max_episode_steps=200,
disable_env_checker=True,
)
# Box2d
@@ -161,11 +164,13 @@ register(
register(
id="tabular/Blackjack-v0",
entry_point="gymnasium.envs.tabular.blackjack:BlackJackJaxEnv",
disable_env_checker=True,
)
register(
id="tabular/CliffWalking-v0",
entry_point="gymnasium.envs.tabular.cliffwalking:CliffWalkingJaxEnv",
disable_env_checker=True,
)

View File

@@ -6,14 +6,12 @@ from typing import Any
import jax
import jax.numpy as jnp
import jax.random as jrng
import numpy as np
import gymnasium as gym
from gymnasium.envs.registration import EnvSpec
from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.utils import seeding
from gymnasium.vector.utils import batch_space
from gymnasium.wrappers.jax_to_numpy import jax_to_numpy
class FunctionalJaxEnv(gym.Env):
@@ -32,7 +30,8 @@ class FunctionalJaxEnv(gym.Env):
):
"""Initialize the environment from a FuncEnv."""
if metadata is None:
metadata = {"render_mode": []}
# metadata.get("jax", False) can be used downstream to know that the environment returns jax arrays
metadata = {"render_mode": [], "jax": True}
self.func_env = func_env
@@ -45,8 +44,6 @@ class FunctionalJaxEnv(gym.Env):
self.spec = spec
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
if self.render_mode == "rgb_array":
self.render_state = self.func_env.render_init()
else:
@@ -69,20 +66,10 @@ class FunctionalJaxEnv(gym.Env):
obs = self.func_env.observation(self.state)
info = self.func_env.state_info(self.state)
obs = jax_to_numpy(obs)
return obs, info
def step(self, action: ActType):
"""Steps through the environment using the action."""
if self._is_box_action_space:
assert isinstance(self.action_space, gym.spaces.Box) # For typing
action = np.clip(action, self.action_space.low, self.action_space.high)
else: # Discrete
# For now we assume jax envs don't use complex spaces
err_msg = f"{action!r} ({type(action)}) invalid"
assert self.action_space.contains(action), err_msg
rng, self.rng = jrng.split(self.rng)
next_state = self.func_env.transition(self.state, action, rng)
@@ -92,8 +79,6 @@ class FunctionalJaxEnv(gym.Env):
info = self.func_env.transition_info(self.state, action, next_state)
self.state = next_state
observation = jax_to_numpy(observation)
return observation, float(reward), bool(terminated), False, info
def render(self):
@@ -153,8 +138,6 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_)
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
if self.render_mode == "rgb_array":
self.render_state = self.func_env.render_init()
else:
@@ -183,20 +166,10 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)
obs = jax_to_numpy(obs)
return obs, info
def step(self, action: ActType):
"""Steps through the environment using the action."""
if self._is_box_action_space:
assert isinstance(self.action_space, gym.spaces.Box) # For typing
action = np.clip(action, self.action_space.low, self.action_space.high)
else: # Discrete
# For now we assume jax envs don't use complex spaces
assert self.action_space.contains(
action
), f"{action!r} ({type(action)}) invalid"
self.steps += 1
rng, self.rng = jrng.split(self.rng)
@@ -232,12 +205,6 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
self.autoreset_envs = done
observation = self.func_env.observation(next_state)
observation = jax_to_numpy(observation)
reward = jax_to_numpy(reward)
terminated = jax_to_numpy(terminated)
truncated = jax_to_numpy(truncated)
self.state = next_state

View File

@@ -245,7 +245,7 @@ class CartPoleFunctional(
class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based implementation of the CartPole environment."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
def __init__(self, render_mode: str | None = None, **kwargs: Any):
"""Constructor for the CartPole where the kwargs are applied to the functional environment."""
@@ -265,7 +265,7 @@ class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
"""Jax-based implementation of the vectorized CartPole environment."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
def __init__(
self,

View File

@@ -223,7 +223,7 @@ class PendulumFunctional(
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based pendulum environment using the functional version as base."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 30}
metadata = {"render_modes": ["rgb_array"], "render_fps": 30, "jax": True}
def __init__(self, render_mode: str | None = None, **kwargs: Any):
"""Constructor where the kwargs are passed to the base environment to modify the parameters."""
@@ -242,7 +242,7 @@ class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
class PendulumJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
"""Jax-based implementation of the vectorized CartPole environment."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
def __init__(
self,

View File

@@ -496,7 +496,7 @@ class BlackjackFunctional(
class BlackJackJaxEnv(FunctionalJaxEnv, EzPickle):
"""A Gymnasium Env wrapper for the functional blackjack env."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
def __init__(self, render_mode: Optional[str] = None, **kwargs):
"""Initializes Gym wrapper for blackjack functional env."""

View File

@@ -358,7 +358,7 @@ class CliffWalkingFunctional(
class CliffWalkingJaxEnv(FunctionalJaxEnv, EzPickle):
"""A Gymnasium Env wrapper for the functional cliffwalking env."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 50}
metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
def __init__(self, render_mode: str | None = None, **kwargs):
"""Initializes Gym wrapper for cliffwalking functional env."""

View File

@@ -367,6 +367,11 @@ def check_env(
f"The environment ({env}) is different from the unwrapped version ({env.unwrapped}). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`."
)
if env.metadata.get("jax", False):
env = gym.wrappers.JaxToNumpy(env)
elif env.metadata.get("torch", False):
env = gym.wrappers.JaxToTorch(env)
# ============= Check the spaces (observation and action) ================
if not hasattr(env, "action_space"):
raise AttributeError(

View File

@@ -120,7 +120,7 @@ def test_vectorized(env_class):
obs, info = env.reset(seed=0)
assert obs.shape == (10,) + env.single_observation_space.shape
assert isinstance(obs, np.ndarray)
assert isinstance(obs, jax.Array)
assert isinstance(info, dict)
for t in range(100):
@@ -128,13 +128,13 @@ def test_vectorized(env_class):
obs, reward, terminated, truncated, info = env.step(action)
assert obs.shape == (10,) + env.single_observation_space.shape
assert isinstance(obs, np.ndarray)
assert isinstance(obs, jax.Array)
assert reward.shape == (10,)
assert isinstance(reward, np.ndarray)
assert isinstance(reward, jax.Array)
assert terminated.shape == (10,)
assert isinstance(terminated, np.ndarray)
assert isinstance(terminated, jax.Array)
assert truncated.shape == (10,)
assert isinstance(truncated, np.ndarray)
assert isinstance(truncated, jax.Array)
assert isinstance(info, dict)
# These were removed in the new autoreset order

View File

@@ -70,6 +70,9 @@ def test_discrete_actions_out_of_bound(env: gym.Env):
Args:
env (gym.Env): the gymnasium environment
"""
if env.metadata.get("jax", False):
return
assert isinstance(env.action_space, spaces.Discrete)
upper_bound = env.action_space.start + env.action_space.n - 1
@@ -102,6 +105,9 @@ def test_box_actions_out_of_bound(env: gym.Env):
Args:
env (gym.Env): the gymnasium environment
"""
if env.metadata.get("jax", False):
return
env.reset(seed=42)
assert env.spec is not None

View File

@@ -7,11 +7,7 @@ import pytest
import gymnasium as gym
from gymnasium.envs.registration import EnvSpec
from gymnasium.utils.env_checker import check_env, data_equivalence
from tests.envs.utils import (
all_testing_env_specs,
all_testing_initialised_envs,
assert_equals,
)
from tests.envs.utils import all_testing_env_specs, all_testing_initialised_envs
# This runs a smoketest on each official registered env. We may want
@@ -42,6 +38,7 @@ def test_all_env_api(spec):
"""Check that all environments pass the environment checker with no warnings other than the expected."""
with warnings.catch_warnings(record=True) as caught_warnings:
env = spec.make().unwrapped
check_env(env, skip_render_check=True)
env.close()
@@ -98,9 +95,13 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
env_1 = env_spec.make(disable_env_checker=True)
env_2 = env_spec.make(disable_env_checker=True)
if env_1.metadata.get("jax", False):
env_1 = gym.wrappers.JaxToNumpy(env_1)
env_2 = gym.wrappers.JaxToNumpy(env_2)
initial_obs_1, initial_info_1 = env_1.reset(seed=SEED)
initial_obs_2, initial_info_2 = env_2.reset(seed=SEED)
assert_equals(initial_obs_1, initial_obs_2)
assert data_equivalence(initial_obs_1, initial_obs_2, exact=True)
env_1.action_space.seed(SEED)
@@ -111,7 +112,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)
assert_equals(obs_1, obs_2, f"[{time_step}] ")
assert data_equivalence(
obs_1, obs_2, exact=True
), f"[{time_step}] obs_1={obs_1}, obs_2={obs_2}"
assert env_1.observation_space.contains(
obs_1
) # obs_2 verified by previous assertion
@@ -123,7 +126,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
assert (
truncated_1 == truncated_2
), f"[{time_step}] done 1={truncated_1}, done 2={truncated_2}"
assert_equals(info_1, info_2, f"[{time_step}] ")
assert data_equivalence(
info_1, info_2, exact=True
), f"[{time_step}] info_1={info_1}, info_2={info_2}"
if (
terminated_1 or truncated_1
@@ -141,6 +146,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
)
def test_pickle_env(env: gym.Env):
if env.metadata.get("jax", False):
env = gym.wrappers.JaxToNumpy(env)
pickled_env = pickle.loads(pickle.dumps(env))
data_equivalence(env.reset(), pickled_env.reset())

View File

@@ -1,8 +1,6 @@
"""Finds all the specs that we can test with"""
from typing import List, Optional
import numpy as np
import gymnasium as gym
from gymnasium import logger
from gymnasium.envs.registration import EnvSpec
@@ -55,28 +53,3 @@ gym_testing_env_specs: List[EnvSpec] = [
for ep in ["box2d", "classic_control", "toy_text"]
)
]
def assert_equals(a, b, prefix=None):
    """Assert equality of data structures `a` and `b`.

    Dictionaries and tuples are compared recursively, numpy arrays are compared
    element-wise, and every other value is compared with ``==``.

    Args:
        a: first data structure
        b: second data structure
        prefix: prefix for failed assertion messages; ``None`` means no prefix

    Raises:
        AssertionError: if the structures differ in type, keys, length or content
    """
    # Normalise once so that messages never start with the literal text "None".
    prefix = "" if prefix is None else prefix

    assert type(a) is type(b), f"{prefix}Differing types: {a} and {b}"
    if isinstance(a, dict):
        assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"
        for key, value_a in a.items():
            # Forward the prefix so nested failures keep their context.
            assert_equals(value_a, b[key], prefix)
    elif isinstance(a, np.ndarray):
        np.testing.assert_array_equal(a, b, err_msg=prefix)
    elif isinstance(a, tuple):
        # zip() stops at the shorter tuple, so a length mismatch has to be
        # checked explicitly or trailing elements would be ignored.
        assert len(a) == len(b), f"{prefix}Differing lengths: {a} and {b}"
        for elem_from_a, elem_from_b in zip(a, b):
            assert_equals(elem_from_a, elem_from_b, prefix)
    else:
        assert a == b, f"{prefix}Differing values: {a} and {b}"

View File

@@ -19,6 +19,9 @@ from tests.testing_env import GenericTestEnv
ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
)
def test_passive_checker_wrapper_warnings(env):
if env.spec is not None and env.spec.disable_env_checker:
return
with warnings.catch_warnings(record=True) as caught_warnings:
checker_env = PassiveEnvChecker(env)
checker_env.reset()