2022-12-03 19:46:01 +00:00
|
|
|
"""Helper functions and wrapper class for converting between numpy and Jax."""
|
2024-06-10 17:07:47 +01:00
|
|
|
|
2022-12-03 19:46:01 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import functools
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
2023-02-24 11:34:20 +00:00
|
|
|
import gymnasium as gym
|
2025-05-12 00:10:06 +02:00
|
|
|
from gymnasium.core import ActType, ObsType
|
2022-12-05 19:14:56 +00:00
|
|
|
from gymnasium.error import DependencyNotInstalled
|
2025-05-12 00:10:06 +02:00
|
|
|
from gymnasium.wrappers.array_conversion import (
|
|
|
|
ArrayConversion,
|
|
|
|
array_conversion,
|
|
|
|
module_namespace,
|
|
|
|
)
|
2022-12-05 19:14:56 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Jax is a hard requirement of this module: if it cannot be imported, fail at
# import time with an actionable installation hint.
try:
    import jax.numpy as jnp
except ImportError as err:
    # Chain the original ImportError so the underlying reason (e.g. a broken
    # jaxlib install rather than a missing package) stays visible as the cause.
    raise DependencyNotInstalled(
        'Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install "gymnasium[jax]"`'
    ) from err
|
|
|
|
|
2023-11-07 13:27:25 +00:00
|
|
|
# Names exported by ``from <this module> import *``: the wrapper class and the
# two conversion helpers defined below.
__all__ = ["JaxToNumpy", "jax_to_numpy", "numpy_to_jax"]
|
2022-12-03 19:46:01 +00:00
|
|
|
|
2025-02-13 23:14:37 +01:00
|
|
|
|
2025-05-12 00:10:06 +02:00
|
|
|
# ``array_conversion`` with its ``xp`` argument pre-bound to the namespace that
# ``module_namespace`` returns for NumPy. Presumably this converts Jax arrays
# (and containers of them) to NumPy arrays -- confirm against
# ``gymnasium.wrappers.array_conversion``.
jax_to_numpy = functools.partial(array_conversion, xp=module_namespace(np))
|
2022-12-03 19:46:01 +00:00
|
|
|
|
2025-05-12 00:10:06 +02:00
|
|
|
# ``array_conversion`` with its ``xp`` argument pre-bound to the namespace that
# ``module_namespace`` returns for ``jax.numpy``. Presumably this converts NumPy
# arrays (and containers of them) to Jax arrays -- confirm against
# ``gymnasium.wrappers.array_conversion``.
numpy_to_jax = functools.partial(array_conversion, xp=module_namespace(jnp))
|
2023-03-17 21:00:48 +00:00
|
|
|
|
|
|
|
|
2025-05-12 00:10:06 +02:00
|
|
|
class JaxToNumpy(ArrayConversion):
    """Wraps a Jax-based environment such that it can be interacted with NumPy arrays.

    Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
    A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.JaxToNumpy`.

    Notes:
        The Jax To Numpy and Numpy to Jax conversion does not guarantee a roundtrip (jax -> numpy -> jax) and vice versa.
        The reason for this is jax does not support non-array values, therefore numpy ``np.int32(5) -> DeviceArray([5], dtype=jnp.int32)``

    Example:
        >>> import gymnasium as gym  # doctest: +SKIP
        >>> env = gym.make("JaxEnv-vx")  # doctest: +SKIP
        >>> env = JaxToNumpy(env)  # doctest: +SKIP
        >>> obs, _ = env.reset(seed=123)  # doctest: +SKIP
        >>> type(obs)  # doctest: +SKIP
        <class 'numpy.ndarray'>
        >>> action = env.action_space.sample()  # doctest: +SKIP
        >>> obs, reward, terminated, truncated, info = env.step(action)  # doctest: +SKIP
        >>> type(obs)  # doctest: +SKIP
        <class 'numpy.ndarray'>
        >>> type(reward)  # doctest: +SKIP
        <class 'float'>
        >>> type(terminated)  # doctest: +SKIP
        <class 'bool'>
        >>> type(truncated)  # doctest: +SKIP
        <class 'bool'>

    Change logs:
     * v1.0.0 - Initially added
    """

    def __init__(self, env: gym.Env[ObsType, ActType]):
        """Wraps a jax environment such that the input and outputs are numpy arrays.

        Args:
            env: the jax environment to wrap
        """
        # No ``jnp is None`` guard is needed here: the module-level import of
        # ``jax.numpy`` either binds ``jnp`` to the real module or raises
        # ``DependencyNotInstalled`` before this class is ever defined.
        super().__init__(env=env, env_xp=jnp, target_xp=np)
|