Jax environment return jax data rather than numpy data (#817)

Co-authored-by: pseudo-rnd-thoughts <mark.m.towers@gmail.com>
This commit is contained in:
Ariel Kwiatkowski
2024-04-05 18:21:10 +02:00
committed by GitHub
parent f0202ae350
commit d43037920f
12 changed files with 48 additions and 81 deletions

View File

@@ -60,6 +60,7 @@ register(
vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv", vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv",
max_episode_steps=200, max_episode_steps=200,
reward_threshold=195.0, reward_threshold=195.0,
disable_env_checker=True,
) )
register( register(
@@ -68,6 +69,7 @@ register(
vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv", vector_entry_point="gymnasium.envs.phys2d.cartpole:CartPoleJaxVectorEnv",
max_episode_steps=500, max_episode_steps=500,
reward_threshold=475.0, reward_threshold=475.0,
disable_env_checker=True,
) )
register( register(
@@ -75,6 +77,7 @@ register(
entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxEnv", entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxEnv",
vector_entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxVectorEnv", vector_entry_point="gymnasium.envs.phys2d.pendulum:PendulumJaxVectorEnv",
max_episode_steps=200, max_episode_steps=200,
disable_env_checker=True,
) )
# Box2d # Box2d
@@ -161,11 +164,13 @@ register(
register( register(
id="tabular/Blackjack-v0", id="tabular/Blackjack-v0",
entry_point="gymnasium.envs.tabular.blackjack:BlackJackJaxEnv", entry_point="gymnasium.envs.tabular.blackjack:BlackJackJaxEnv",
disable_env_checker=True,
) )
register( register(
id="tabular/CliffWalking-v0", id="tabular/CliffWalking-v0",
entry_point="gymnasium.envs.tabular.cliffwalking:CliffWalkingJaxEnv", entry_point="gymnasium.envs.tabular.cliffwalking:CliffWalkingJaxEnv",
disable_env_checker=True,
) )

View File

@@ -6,14 +6,12 @@ from typing import Any
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax.random as jrng import jax.random as jrng
import numpy as np
import gymnasium as gym import gymnasium as gym
from gymnasium.envs.registration import EnvSpec from gymnasium.envs.registration import EnvSpec
from gymnasium.functional import ActType, FuncEnv, StateType from gymnasium.functional import ActType, FuncEnv, StateType
from gymnasium.utils import seeding from gymnasium.utils import seeding
from gymnasium.vector.utils import batch_space from gymnasium.vector.utils import batch_space
from gymnasium.wrappers.jax_to_numpy import jax_to_numpy
class FunctionalJaxEnv(gym.Env): class FunctionalJaxEnv(gym.Env):
@@ -32,7 +30,8 @@ class FunctionalJaxEnv(gym.Env):
): ):
"""Initialize the environment from a FuncEnv.""" """Initialize the environment from a FuncEnv."""
if metadata is None: if metadata is None:
metadata = {"render_mode": []} # metadata.get("jax", False) can be used downstream to know that the environment returns jax arrays
metadata = {"render_mode": [], "jax": True}
self.func_env = func_env self.func_env = func_env
@@ -45,8 +44,6 @@ class FunctionalJaxEnv(gym.Env):
self.spec = spec self.spec = spec
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
if self.render_mode == "rgb_array": if self.render_mode == "rgb_array":
self.render_state = self.func_env.render_init() self.render_state = self.func_env.render_init()
else: else:
@@ -69,20 +66,10 @@ class FunctionalJaxEnv(gym.Env):
obs = self.func_env.observation(self.state) obs = self.func_env.observation(self.state)
info = self.func_env.state_info(self.state) info = self.func_env.state_info(self.state)
obs = jax_to_numpy(obs)
return obs, info return obs, info
def step(self, action: ActType): def step(self, action: ActType):
"""Steps through the environment using the action.""" """Steps through the environment using the action."""
if self._is_box_action_space:
assert isinstance(self.action_space, gym.spaces.Box) # For typing
action = np.clip(action, self.action_space.low, self.action_space.high)
else: # Discrete
# For now we assume jax envs don't use complex spaces
err_msg = f"{action!r} ({type(action)}) invalid"
assert self.action_space.contains(action), err_msg
rng, self.rng = jrng.split(self.rng) rng, self.rng = jrng.split(self.rng)
next_state = self.func_env.transition(self.state, action, rng) next_state = self.func_env.transition(self.state, action, rng)
@@ -92,8 +79,6 @@ class FunctionalJaxEnv(gym.Env):
info = self.func_env.transition_info(self.state, action, next_state) info = self.func_env.transition_info(self.state, action, next_state)
self.state = next_state self.state = next_state
observation = jax_to_numpy(observation)
return observation, float(reward), bool(terminated), False, info return observation, float(reward), bool(terminated), False, info
def render(self): def render(self):
@@ -153,8 +138,6 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_) self.autoreset_envs = jnp.zeros(self.num_envs, dtype=jnp.bool_)
self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box)
if self.render_mode == "rgb_array": if self.render_mode == "rgb_array":
self.render_state = self.func_env.render_init() self.render_state = self.func_env.render_init()
else: else:
@@ -183,20 +166,10 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32) self.steps = jnp.zeros(self.num_envs, dtype=jnp.int32)
obs = jax_to_numpy(obs)
return obs, info return obs, info
def step(self, action: ActType): def step(self, action: ActType):
"""Steps through the environment using the action.""" """Steps through the environment using the action."""
if self._is_box_action_space:
assert isinstance(self.action_space, gym.spaces.Box) # For typing
action = np.clip(action, self.action_space.low, self.action_space.high)
else: # Discrete
# For now we assume jax envs don't use complex spaces
assert self.action_space.contains(
action
), f"{action!r} ({type(action)}) invalid"
self.steps += 1 self.steps += 1
rng, self.rng = jrng.split(self.rng) rng, self.rng = jrng.split(self.rng)
@@ -232,12 +205,6 @@ class FunctionalJaxVectorEnv(gym.vector.VectorEnv):
self.autoreset_envs = done self.autoreset_envs = done
observation = self.func_env.observation(next_state) observation = self.func_env.observation(next_state)
observation = jax_to_numpy(observation)
reward = jax_to_numpy(reward)
terminated = jax_to_numpy(terminated)
truncated = jax_to_numpy(truncated)
self.state = next_state self.state = next_state

View File

@@ -245,7 +245,7 @@ class CartPoleFunctional(
class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle): class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based implementation of the CartPole environment.""" """Jax-based implementation of the CartPole environment."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 50} metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
def __init__(self, render_mode: str | None = None, **kwargs: Any): def __init__(self, render_mode: str | None = None, **kwargs: Any):
"""Constructor for the CartPole where the kwargs are applied to the functional environment.""" """Constructor for the CartPole where the kwargs are applied to the functional environment."""
@@ -265,7 +265,7 @@ class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle):
class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle): class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
"""Jax-based implementation of the vectorized CartPole environment.""" """Jax-based implementation of the vectorized CartPole environment."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 50} metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
def __init__( def __init__(
self, self,

View File

@@ -223,7 +223,7 @@ class PendulumFunctional(
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle): class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based pendulum environment using the functional version as base.""" """Jax-based pendulum environment using the functional version as base."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 30} metadata = {"render_modes": ["rgb_array"], "render_fps": 30, "jax": True}
def __init__(self, render_mode: str | None = None, **kwargs: Any): def __init__(self, render_mode: str | None = None, **kwargs: Any):
"""Constructor where the kwargs are passed to the base environment to modify the parameters.""" """Constructor where the kwargs are passed to the base environment to modify the parameters."""
@@ -242,7 +242,7 @@ class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
class PendulumJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle): class PendulumJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
"""Jax-based implementation of the vectorized CartPole environment.""" """Jax-based implementation of the vectorized CartPole environment."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 50} metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
def __init__( def __init__(
self, self,

View File

@@ -496,7 +496,7 @@ class BlackjackFunctional(
class BlackJackJaxEnv(FunctionalJaxEnv, EzPickle): class BlackJackJaxEnv(FunctionalJaxEnv, EzPickle):
"""A Gymnasium Env wrapper for the functional blackjack env.""" """A Gymnasium Env wrapper for the functional blackjack env."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 50} metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
def __init__(self, render_mode: Optional[str] = None, **kwargs): def __init__(self, render_mode: Optional[str] = None, **kwargs):
"""Initializes Gym wrapper for blackjack functional env.""" """Initializes Gym wrapper for blackjack functional env."""

View File

@@ -358,7 +358,7 @@ class CliffWalkingFunctional(
class CliffWalkingJaxEnv(FunctionalJaxEnv, EzPickle): class CliffWalkingJaxEnv(FunctionalJaxEnv, EzPickle):
"""A Gymnasium Env wrapper for the functional cliffwalking env.""" """A Gymnasium Env wrapper for the functional cliffwalking env."""
metadata = {"render_modes": ["rgb_array"], "render_fps": 50} metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
def __init__(self, render_mode: str | None = None, **kwargs): def __init__(self, render_mode: str | None = None, **kwargs):
"""Initializes Gym wrapper for cliffwalking functional env.""" """Initializes Gym wrapper for cliffwalking functional env."""

View File

@@ -367,6 +367,11 @@ def check_env(
f"The environment ({env}) is different from the unwrapped version ({env.unwrapped}). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`." f"The environment ({env}) is different from the unwrapped version ({env.unwrapped}). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`."
) )
if env.metadata.get("jax", False):
env = gym.wrappers.JaxToNumpy(env)
elif env.metadata.get("torch", False):
env = gym.wrappers.JaxToTorch(env)
# ============= Check the spaces (observation and action) ================ # ============= Check the spaces (observation and action) ================
if not hasattr(env, "action_space"): if not hasattr(env, "action_space"):
raise AttributeError( raise AttributeError(

View File

@@ -120,7 +120,7 @@ def test_vectorized(env_class):
obs, info = env.reset(seed=0) obs, info = env.reset(seed=0)
assert obs.shape == (10,) + env.single_observation_space.shape assert obs.shape == (10,) + env.single_observation_space.shape
assert isinstance(obs, np.ndarray) assert isinstance(obs, jax.Array)
assert isinstance(info, dict) assert isinstance(info, dict)
for t in range(100): for t in range(100):
@@ -128,13 +128,13 @@ def test_vectorized(env_class):
obs, reward, terminated, truncated, info = env.step(action) obs, reward, terminated, truncated, info = env.step(action)
assert obs.shape == (10,) + env.single_observation_space.shape assert obs.shape == (10,) + env.single_observation_space.shape
assert isinstance(obs, np.ndarray) assert isinstance(obs, jax.Array)
assert reward.shape == (10,) assert reward.shape == (10,)
assert isinstance(reward, np.ndarray) assert isinstance(reward, jax.Array)
assert terminated.shape == (10,) assert terminated.shape == (10,)
assert isinstance(terminated, np.ndarray) assert isinstance(terminated, jax.Array)
assert truncated.shape == (10,) assert truncated.shape == (10,)
assert isinstance(truncated, np.ndarray) assert isinstance(truncated, jax.Array)
assert isinstance(info, dict) assert isinstance(info, dict)
# These were removed in the new autoreset order # These were removed in the new autoreset order

View File

@@ -70,6 +70,9 @@ def test_discrete_actions_out_of_bound(env: gym.Env):
Args: Args:
env (gym.Env): the gymnasium environment env (gym.Env): the gymnasium environment
""" """
if env.metadata.get("jax", False):
return
assert isinstance(env.action_space, spaces.Discrete) assert isinstance(env.action_space, spaces.Discrete)
upper_bound = env.action_space.start + env.action_space.n - 1 upper_bound = env.action_space.start + env.action_space.n - 1
@@ -102,6 +105,9 @@ def test_box_actions_out_of_bound(env: gym.Env):
Args: Args:
env (gym.Env): the gymnasium environment env (gym.Env): the gymnasium environment
""" """
if env.metadata.get("jax", False):
return
env.reset(seed=42) env.reset(seed=42)
assert env.spec is not None assert env.spec is not None

View File

@@ -7,11 +7,7 @@ import pytest
import gymnasium as gym import gymnasium as gym
from gymnasium.envs.registration import EnvSpec from gymnasium.envs.registration import EnvSpec
from gymnasium.utils.env_checker import check_env, data_equivalence from gymnasium.utils.env_checker import check_env, data_equivalence
from tests.envs.utils import ( from tests.envs.utils import all_testing_env_specs, all_testing_initialised_envs
all_testing_env_specs,
all_testing_initialised_envs,
assert_equals,
)
# This runs a smoketest on each official registered env. We may want # This runs a smoketest on each official registered env. We may want
@@ -42,6 +38,7 @@ def test_all_env_api(spec):
"""Check that all environments pass the environment checker with no warnings other than the expected.""" """Check that all environments pass the environment checker with no warnings other than the expected."""
with warnings.catch_warnings(record=True) as caught_warnings: with warnings.catch_warnings(record=True) as caught_warnings:
env = spec.make().unwrapped env = spec.make().unwrapped
check_env(env, skip_render_check=True) check_env(env, skip_render_check=True)
env.close() env.close()
@@ -98,9 +95,13 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
env_1 = env_spec.make(disable_env_checker=True) env_1 = env_spec.make(disable_env_checker=True)
env_2 = env_spec.make(disable_env_checker=True) env_2 = env_spec.make(disable_env_checker=True)
if env_1.metadata.get("jax", False):
env_1 = gym.wrappers.JaxToNumpy(env_1)
env_2 = gym.wrappers.JaxToNumpy(env_2)
initial_obs_1, initial_info_1 = env_1.reset(seed=SEED) initial_obs_1, initial_info_1 = env_1.reset(seed=SEED)
initial_obs_2, initial_info_2 = env_2.reset(seed=SEED) initial_obs_2, initial_info_2 = env_2.reset(seed=SEED)
assert_equals(initial_obs_1, initial_obs_2) assert data_equivalence(initial_obs_1, initial_obs_2, exact=True)
env_1.action_space.seed(SEED) env_1.action_space.seed(SEED)
@@ -111,7 +112,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action) obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action) obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)
assert_equals(obs_1, obs_2, f"[{time_step}] ") assert data_equivalence(
obs_1, obs_2, exact=True
), f"[{time_step}] obs_1={obs_1}, obs_2={obs_2}"
assert env_1.observation_space.contains( assert env_1.observation_space.contains(
obs_1 obs_1
) # obs_2 verified by previous assertion ) # obs_2 verified by previous assertion
@@ -123,7 +126,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
assert ( assert (
truncated_1 == truncated_2 truncated_1 == truncated_2
), f"[{time_step}] done 1={truncated_1}, done 2={truncated_2}" ), f"[{time_step}] done 1={truncated_1}, done 2={truncated_2}"
assert_equals(info_1, info_2, f"[{time_step}] ") assert data_equivalence(
info_1, info_2, exact=True
), f"[{time_step}] info_1={info_1}, info_2={info_2}"
if ( if (
terminated_1 or truncated_1 terminated_1 or truncated_1
@@ -141,6 +146,9 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None], ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
) )
def test_pickle_env(env: gym.Env): def test_pickle_env(env: gym.Env):
if env.metadata.get("jax", False):
env = gym.wrappers.JaxToNumpy(env)
pickled_env = pickle.loads(pickle.dumps(env)) pickled_env = pickle.loads(pickle.dumps(env))
data_equivalence(env.reset(), pickled_env.reset()) data_equivalence(env.reset(), pickled_env.reset())

View File

@@ -1,8 +1,6 @@
"""Finds all the specs that we can test with""" """Finds all the specs that we can test with"""
from typing import List, Optional from typing import List, Optional
import numpy as np
import gymnasium as gym import gymnasium as gym
from gymnasium import logger from gymnasium import logger
from gymnasium.envs.registration import EnvSpec from gymnasium.envs.registration import EnvSpec
@@ -55,28 +53,3 @@ gym_testing_env_specs: List[EnvSpec] = [
for ep in ["box2d", "classic_control", "toy_text"] for ep in ["box2d", "classic_control", "toy_text"]
) )
] ]
def assert_equals(a, b, prefix=None):
"""Assert equality of data structures `a` and `b`.
Args:
a: first data structure
b: second data structure
prefix: prefix for failed assertion message for types and dicts
"""
assert type(a) is type(b), f"{prefix}Differing types: {a} and {b}"
if isinstance(a, dict):
assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"
for k in a.keys():
v_a = a[k]
v_b = b[k]
assert_equals(v_a, v_b)
elif isinstance(a, np.ndarray):
np.testing.assert_array_equal(a, b)
elif isinstance(a, tuple):
for elem_from_a, elem_from_b in zip(a, b):
assert_equals(elem_from_a, elem_from_b)
else:
assert a == b

View File

@@ -19,6 +19,9 @@ from tests.testing_env import GenericTestEnv
ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None], ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None],
) )
def test_passive_checker_wrapper_warnings(env): def test_passive_checker_wrapper_warnings(env):
if env.spec is not None and env.spec.disable_env_checker:
return
with warnings.catch_warnings(record=True) as caught_warnings: with warnings.catch_warnings(record=True) as caught_warnings:
checker_env = PassiveEnvChecker(env) checker_env = PassiveEnvChecker(env)
checker_env.reset() checker_env.reset()